Skip to content

Commit

Permalink
Simplify grid getter and add tests for it
Browse files Browse the repository at this point in the history
  • Loading branch information
cscjlan committed Oct 18, 2024
1 parent c642c08 commit c8d3e35
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 114 deletions.
156 changes: 42 additions & 114 deletions src/grid.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <iomanip>
#include <ios>
#include <iostream>
#include <limits>
#include <mpi.h>
#include <numeric>
#include <sstream>
Expand Down Expand Up @@ -264,7 +265,7 @@ template <typename T, int32_t stencil> class FsGrid {
* \param MPI_Comm The MPI communicator this grid should use.
* \param periodic An array specifying, for each dimension, whether it is to be treated as periodic.
*/
FsGrid(std::array<FsSize_t, 3> globalSize, MPI_Comm parentComm, std::array<bool, 3> periodic,
FsGrid(const std::array<FsSize_t, 3>& globalSize, MPI_Comm parentComm, const std::array<bool, 3>& periodic,
const std::array<double, 3>& physicalGridSpacing, const std::array<double, 3>& physicalGlobalStart,
const std::array<Task_t, 3>& decomposition = {0, 0, 0})
: comm3d(fsgrid_detail::createCartesianCommunicator(
Expand All @@ -281,14 +282,14 @@ template <typename T, int32_t stencil> class FsGrid {
fsgrid_detail::getTaskPosition(comm3d), coordinates.numTasksPerDim, periodic, comm3d, rank)),
neighbourRankToIndex(fsgrid_detail::mapNeighbourRankToIndex(
neighbourIndexToRank, fsgrid_detail::getFSCommSize(fsgrid_detail::getCommSize(parentComm)))),
data(rank == -1
? 0ul
: static_cast<size_t>(std::accumulate(coordinates.storageSize.cbegin(),
coordinates.storageSize.cend(), 1, std::multiplies<>()))),
neighbourSendType(
fsgrid_detail::generateMPITypes<T>(coordinates.storageSize, coordinates.localSize, stencil, true)),
neighbourReceiveType(
fsgrid_detail::generateMPITypes<T>(coordinates.storageSize, coordinates.localSize, stencil, false)) {}
fsgrid_detail::generateMPITypes<T>(coordinates.storageSize, coordinates.localSize, stencil, false)),
data(rank == -1
? 0ul
: static_cast<size_t>(std::accumulate(coordinates.storageSize.cbegin(),
coordinates.storageSize.cend(), 1, std::multiplies<>()))) {}

/*! Finalize instead of destructor, as the MPI calls fail after the main program called MPI_Finalize().
* Cleans up the cartesian communicator
Expand Down Expand Up @@ -322,120 +323,18 @@ template <typename T, int32_t stencil> class FsGrid {
bool localIdInBounds(LocalID id) const { return 0 <= id && (size_t)id < data.size(); }

T* get(LocalID id) {
if (!localIdInBounds(id)) {
std::cerr << "Out-of-bounds access in FsGrid::get!" << std::endl
<< "(LocalID = " << id << ", but storage space is " << data.size() << ". Expect weirdness."
<< std::endl;
return NULL;
}
return &data[id];
FSGRID_DEBUG_ASSERT(localIdInBounds(id), "Out-of bounds access in FsGrid::get!", "(LocalID = ", id,
", but storage space is ", data.size(), ". Expect weirdness.");
return localIdInBounds(id) ? &data[static_cast<size_t>(id)] : nullptr;
}

// TODO: test
/*! Get a reference to the field data in a cell
* \param x x-Coordinate, in cells
* \param y y-Coordinate, in cells
* \param z z-Coordinate, in cells
* \return A reference to cell data in the given cell
*/
T* get(FsIndex_t x, FsIndex_t y, FsIndex_t z) {

// Keep track which neighbour this cell actually belongs to (13 = ourself)
int32_t isInNeighbourDomain = 13;
int32_t coord_shift[3] = {0, 0, 0};
if (x < 0) {
isInNeighbourDomain -= 9;
coord_shift[0] = 1;
}
if (x >= coordinates.localSize[0]) {
isInNeighbourDomain += 9;
coord_shift[0] = -1;
}
if (y < 0) {
isInNeighbourDomain -= 3;
coord_shift[1] = 1;
}
if (y >= coordinates.localSize[1]) {
isInNeighbourDomain += 3;
coord_shift[1] = -1;
}
if (z < 0) {
isInNeighbourDomain -= 1;
coord_shift[2] = 1;
}
if (z >= coordinates.localSize[2]) {
isInNeighbourDomain += 1;
coord_shift[2] = -1;
}

// Santiy-Check that the requested cell is actually inside our domain
// TODO: ugh, this is ugly.
#ifdef FSGRID_DEBUG
bool inside = true;
if (localSize[0] <= 1 && !periodic[0]) {
if (x != 0) {
std::cerr << "x != 0 despite non-periodic x-axis with only one cell." << std::endl;
inside = false;
}
} else {
if (x < -stencil || x >= localSize[0] + stencil) {
std::cerr << "x = " << x << " is outside of [ " << -stencil << ", " << localSize[0] + stencil << "[!"
<< std::endl;
inside = false;
}
}

if (localSize[1] <= 1 && !periodic[1]) {
if (y != 0) {
std::cerr << "y != 0 despite non-periodic y-axis with only one cell." << std::endl;
inside = false;
}
} else {
if (y < -stencil || y >= localSize[1] + stencil) {
std::cerr << "y = " << y << " is outside of [ " << -stencil << ", " << localSize[1] + stencil << "[!"
<< std::endl;
inside = false;
}
}

if (localSize[2] <= 1 && !periodic[2]) {
if (z != 0) {
std::cerr << "z != 0 despite non-periodic z-axis with only one cell." << std::endl;
inside = false;
}
} else {
if (z < -stencil || z >= localSize[2] + stencil) {
inside = false;
std::cerr << "z = " << z << " is outside of [ " << -stencil << ", " << localSize[2] + stencil << "[!"
<< std::endl;
}
}
if (!inside) {
std::cerr << "Out-of bounds access in FsGrid::get! Expect weirdness." << std::endl;
return NULL;
}
#endif // FSGRID_DEBUG

if (isInNeighbourDomain != 13) {

// Check if the corresponding neighbour exists
if (neighbourIndexToRank[isInNeighbourDomain] == MPI_PROC_NULL) {
// Neighbour doesn't exist, we must be an outer boundary cell
// (or something is quite wrong)
return NULL;

} else if (neighbourIndexToRank[isInNeighbourDomain] == rank) {
// For periodic boundaries, where the neighbour is actually ourself,
// return our own actual cell instead of the ghost
x += coord_shift[0] * coordinates.localSize[0];
y += coord_shift[1] * coordinates.localSize[1];
z += coord_shift[2] * coordinates.localSize[2];
}
// Otherwise we return the ghost cell
}

return get(localIDFromLocalCoordinates(x, y, z));
}
T* get(FsIndex_t x, FsIndex_t y, FsIndex_t z) { return get(localIDFromCellCoordinates(x, y, z)); }

// ============================
// Coordinate change functions
Expand All @@ -459,6 +358,34 @@ template <typename T, int32_t stencil> class FsGrid {
return coordinates.physicalToFractionalGlobal(args...);
}

/*! Compute the local id from cell coordinates (these include ghost cells)
* \param x x-Coordinate, in cells
* \param y y-Coordinate, in cells
* \param z z-Coordinate, in cells
* \return local id of the cell
*/
LocalID localIDFromCellCoordinates(FsIndex_t x, FsIndex_t y, FsIndex_t z) const {
FSGRID_DEBUG_ASSERT(coordinates.cellIndicesAreWithinBounds(x, y, z), "Out-of bounds access in FsGrid::get!");
const auto neighbourIndex = coordinates.neighbourIndexFromCellCoordinates(x, y, z);
const auto neighbourRank = neighbourIndexToRank[neighbourIndex];
const auto isSelf = neighbourRank == rank;

FSGRID_DEBUG_ASSERT(isSelf || neighbourRank != MPI_PROC_NULL,
"Trying to access data from a non-existing neighbour");

const auto neighbourIsSelf = neighbourIndex != 13 && isSelf;
const auto id = neighbourIsSelf ? coordinates.localIDFromLocalCoordinates(coordinates.shiftCellIndices(x, y, z))
: coordinates.localIDFromLocalCoordinates(x, y, z);

return coordinates.cellIndicesAreWithinBounds(x, y, z) && (isSelf || neighbourRank != MPI_PROC_NULL)
? id
: std::numeric_limits<LocalID>::min();
}

LocalID localIDFromCellCoordinates(const std::array<FsIndex_t, 3>& indices) const {
return localIDFromCellCoordinates(indices[0], indices[1], indices[2]);
}

// ============================
// Getters
// ============================
Expand Down Expand Up @@ -967,10 +894,11 @@ template <typename T, int32_t stencil> class FsGrid {
};
//!< Lookup table from rank to index in the neighbour array
const std::vector<char> neighbourRankToIndex = {};
//! Actual storage of field data
std::vector<T> data = {};
//!< Datatype for sending data
std::array<MPI_Datatype, 27> neighbourSendType = {};
//!< Datatype for receiving data
std::array<MPI_Datatype, 27> neighbourReceiveType = {};

//! Actual storage of field data
std::vector<T> data = {};
};
78 changes: 78 additions & 0 deletions tests/mpi_tests/grid_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,3 +215,81 @@ TEST(FsGridTest, localIdInBounds) {
}
}
}

TEST(FsGridTest, getNonPeriodic) {
const std::array<FsGridTools::FsSize_t, 3> globalSize{12, 6, 2048};
const MPI_Comm parentComm = MPI_COMM_WORLD;
const std::array<bool, 3> periodic{false, false, false};
constexpr int32_t numGhostCells = 1;

auto grid =
FsGrid<std::array<double, 8>, numGhostCells>(globalSize, parentComm, periodic, {0.0, 0.0, 0.0}, {0.0, 0.0, 0.0});
const auto localSize = grid.getLocalSize();
for (int32_t x = 0; x < localSize[0]; x++) {
for (int32_t y = 0; y < localSize[1]; y++) {
for (int32_t z = 0; z < localSize[2]; z++) {
ASSERT_NE(grid.get(x, y, z), nullptr);
}
}
}

ASSERT_EQ(grid.get(-numGhostCells, 0, 0), nullptr);
ASSERT_EQ(grid.get(-numGhostCells - 1, 0, 0), nullptr);
ASSERT_EQ(grid.get(grid.getLocalSize()[0] + numGhostCells, 0, 0), nullptr);
ASSERT_EQ(grid.get(grid.getLocalSize()[0] + numGhostCells - 1, 0, 0), nullptr);

ASSERT_EQ(grid.get(0, -numGhostCells, 0), nullptr);
ASSERT_EQ(grid.get(0, -numGhostCells - 1, 0), nullptr);
ASSERT_EQ(grid.get(0, grid.getLocalSize()[1] + numGhostCells, 0), nullptr);
ASSERT_EQ(grid.get(0, grid.getLocalSize()[1] + numGhostCells - 1, 0), nullptr);

// This depends on the position on the grid
if (grid.getLocalStart()[2] == 0) {
ASSERT_EQ(grid.get(0, 0, -numGhostCells), nullptr);
} else {
ASSERT_NE(grid.get(0, 0, -numGhostCells), nullptr);
}

ASSERT_EQ(grid.get(0, 0, -numGhostCells - 1), nullptr);

// This depends on the position on the grid
if (grid.getLocalStart()[2] + grid.getLocalSize()[2] == static_cast<FsGridTools::FsIndex_t>(globalSize[2])) {
ASSERT_EQ(grid.get(0, 0, grid.getLocalSize()[2] + numGhostCells - 1), nullptr);
} else {
ASSERT_NE(grid.get(0, 0, grid.getLocalSize()[2] + numGhostCells - 1), nullptr);
}
ASSERT_EQ(grid.get(0, 0, grid.getLocalSize()[2] + numGhostCells), nullptr);
}

TEST(FsGridTest, getPeriodic) {
const std::array<FsGridTools::FsSize_t, 3> globalSize{120, 5, 1048};
const MPI_Comm parentComm = MPI_COMM_WORLD;
const std::array<bool, 3> periodic{true, true, true};
constexpr int32_t numGhostCells = 2;

auto grid =
FsGrid<std::array<double, 8>, numGhostCells>(globalSize, parentComm, periodic, {0.0, 0.0, 0.0}, {0.0, 0.0, 0.0});
const auto localSize = grid.getLocalSize();
for (int32_t x = 0; x < localSize[0]; x++) {
for (int32_t y = 0; y < localSize[1]; y++) {
for (int32_t z = 0; z < localSize[2]; z++) {
ASSERT_NE(grid.get(x, y, z), nullptr);
}
}
}

ASSERT_NE(grid.get(-numGhostCells, 0, 0), nullptr);
ASSERT_EQ(grid.get(-numGhostCells - 1, 0, 0), nullptr);
ASSERT_EQ(grid.get(grid.getLocalSize()[0] + numGhostCells, 0, 0), nullptr);
ASSERT_NE(grid.get(grid.getLocalSize()[0] + numGhostCells - 1, 0, 0), nullptr);

ASSERT_NE(grid.get(0, -numGhostCells, 0), nullptr);
ASSERT_EQ(grid.get(0, -numGhostCells - 1, 0), nullptr);
ASSERT_EQ(grid.get(0, grid.getLocalSize()[1] + numGhostCells, 0), nullptr);
ASSERT_NE(grid.get(0, grid.getLocalSize()[1] + numGhostCells - 1, 0), nullptr);

ASSERT_NE(grid.get(0, 0, -numGhostCells), nullptr);
ASSERT_EQ(grid.get(0, 0, -numGhostCells - 1), nullptr);
ASSERT_NE(grid.get(0, 0, grid.getLocalSize()[2] + numGhostCells - 1), nullptr);
ASSERT_EQ(grid.get(0, 0, grid.getLocalSize()[2] + numGhostCells), nullptr);
}

0 comments on commit c8d3e35

Please sign in to comment.