diff --git a/src/atlas/interpolation/method/paralleltransport/ParallelTransport.cc b/src/atlas/interpolation/method/paralleltransport/ParallelTransport.cc index 30c3820f3..12f68b300 100644 --- a/src/atlas/interpolation/method/paralleltransport/ParallelTransport.cc +++ b/src/atlas/interpolation/method/paralleltransport/ParallelTransport.cc @@ -27,7 +27,6 @@ #include "eckit/linalg/Triplet.h" - namespace atlas { namespace interpolation { namespace method { @@ -35,7 +34,7 @@ namespace method { namespace { MethodBuilder __builder("parallel-transport"); -template +template void spaceMatrixForEach(MatrixT&& matrix, const Functor& functor) { const auto nRows = matrix.rows(); @@ -44,17 +43,21 @@ void spaceMatrixForEach(MatrixT&& matrix, const Functor& functor) { const auto colIndices = matrix.inner(); auto valData = matrix.data(); - atlas_omp_parallel_for (auto i = size_t{}; i < nRows; ++i) { - for (auto dataIdx = rowIndices[i]; dataIdx < rowIndices[i+1]; ++dataIdx) { + atlas_omp_parallel_for(auto i = size_t{}; i < nRows; ++i) { + for (auto dataIdx = rowIndices[i]; dataIdx < rowIndices[i + 1]; ++dataIdx) { const auto j = size_t(colIndices[dataIdx]); - auto& value = valData[dataIdx]; - - if constexpr (std::is_invocable_v) { - functor(value, i, j); - } - else if constexpr (std::is_invocable_v) { - functor(value, i, j, dataIdx); - } + auto&& value = valData[dataIdx]; + + if + constexpr( + std::is_invocable_v) { + functor(value, i, j); + } + else if + constexpr(std::is_invocable_v) { + functor(value, i, j, dataIdx); + } else { ATLAS_NOTIMPLEMENTED; } @@ -62,13 +65,47 @@ void spaceMatrixForEach(MatrixT&& matrix, const Functor& functor) { } } +template +void matrixMultiply(const MatrixT& matrix, SourceView&& sourceView, + TargetView&& targetView, const Functor& mappingFunctor) { + + spaceMatrixForEach(matrix, [&](const auto& weight, auto i, auto j) { + + constexpr auto rank = std::decay_t::rank(); + if + constexpr(rank == 2) { + const auto sourceSlice = sourceView.slice(j, array::Range::all()); + auto targetSlice = targetView.slice(i, array::Range::all()); + mappingFunctor(weight, sourceSlice, targetSlice); + } + else if + constexpr(rank == 3) { + const auto iterationFuctor = [&](auto&& sourceVars, auto&& targetVars) { + mappingFunctor(weight, sourceVars, targetVars); + }; + const auto sourceSlice = + sourceView.slice(j, array::Range::all(), array::Range::all()); + auto targetSlice = + targetView.slice(i, array::Range::all(), array::Range::all()); + array::helpers::ArrayForEach<0>::apply( + std::tie(sourceSlice, targetSlice), iterationFuctor); + } + else { + ATLAS_NOTIMPLEMENTED; + } + }); +} + } // namespace -void ParallelTransport::do_setup(const Grid& source, const Grid& target, const Cache&) { +void ParallelTransport::do_setup(const Grid& source, const Grid& target, + const Cache&) { ATLAS_NOTIMPLEMENTED; } -void ParallelTransport::do_setup(const FunctionSpace& source, const FunctionSpace& target) { +void ParallelTransport::do_setup(const FunctionSpace& source, + const FunctionSpace& target) { ATLAS_TRACE("interpolation::method::ParallelTransport::do_setup"); source_ = source; target_ = target; @@ -77,7 +114,8 @@ void ParallelTransport::do_setup(const FunctionSpace& source, const FunctionSpac return; } - const auto baseInterpolator = Interpolation(interpolationScheme_, source_, target_); + const auto baseInterpolator = + Interpolation(interpolationScheme_, source_, target_); setMatrix(MatrixCache(baseInterpolator)); // Get matrix dimensions. @@ -85,49 +123,44 @@ void ParallelTransport::do_setup(const FunctionSpace& source, const FunctionSpac const auto nCols = matrix().cols(); const auto nNonZeros = matrix().nonZeros(); - auto weights00 = std::vector(nNonZeros); - auto weights01 = std::vector(nNonZeros); - auto weights10 = std::vector(nNonZeros); - auto weights11 = std::vector(nNonZeros); + auto weightsReal = std::vector(nNonZeros); + auto weightsImag = std::vector(nNonZeros); const auto sourceLonLats = array::make_view(source_.lonlat()); const auto targetLonLats = array::make_view(target_.lonlat()); - - spaceMatrixForEach(matrix(), [&](auto&& weight, auto i, auto j, auto dataIdx){ - const auto sourceLonLat = PointLonLat(sourceLonLats(j, 0), sourceLonLats(j, 1)); - const auto targetLonLat = PointLonLat(targetLonLats(i, 0), targetLonLats(i, 1)); + // Make complex weights (would be nice if we could have a complex matrix). + spaceMatrixForEach(matrix(), + [&](auto&& weight, auto i, auto j, auto dataIdx) { + const auto sourceLonLat = + PointLonLat(sourceLonLats(j, 0), sourceLonLats(j, 1)); + const auto targetLonLat = + PointLonLat(targetLonLats(i, 0), targetLonLats(i, 1)); const auto alpha = util::greatCircleCourse(sourceLonLat, targetLonLat); - auto deltaAlpha = (alpha.first - alpha.second) * util::Constants::degreesToRadians(); + auto deltaAlpha = + (alpha.first - alpha.second) * util::Constants::degreesToRadians(); - weights00[dataIdx] = {i, j, weight * std::cos(deltaAlpha)}; - weights01[dataIdx] = {i, j, -weight * std::sin(deltaAlpha)}; - weights10[dataIdx] = {i, j, weight * std::sin(deltaAlpha)}; - weights11[dataIdx] = {i, j, weight * std::cos(deltaAlpha)}; + weightsReal[dataIdx] = {i, j, weight * std::cos(deltaAlpha)}; + weightsImag[dataIdx] = {i, j, weight * std::sin(deltaAlpha)}; }); - // Deal with slightly old fashioned Matrix interface - const auto buildMatrix = [&](auto& matrix, const auto& weights){ + const auto buildMatrix = [&](auto& matrix, const auto& weights) { auto tempMatrix = Matrix(nRows, nCols, weights); matrix.swap(tempMatrix); }; - buildMatrix(matrix00_, weights00); - buildMatrix(matrix10_, weights10); - buildMatrix(matrix01_, weights01); - buildMatrix(matrix11_, weights11); - + buildMatrix(matrixReal_, weightsReal); + buildMatrix(matrixImag_, weightsImag); } -void ParallelTransport::print(std::ostream&) const { - ATLAS_NOTIMPLEMENTED; -} +void ParallelTransport::print(std::ostream&) const { ATLAS_NOTIMPLEMENTED; } -void ParallelTransport::do_execute(const FieldSet& sourceFieldSet, FieldSet& targetFieldSet, Metadata& metadata) const -{ +void ParallelTransport::do_execute(const FieldSet& sourceFieldSet, + FieldSet& targetFieldSet, + Metadata& metadata) const { ATLAS_TRACE("atlas::interpolation::method::ParallelTransport::do_execute()"); const auto nFields = sourceFieldSet.size(); @@ -138,8 +171,8 @@ void ParallelTransport::do_execute(const FieldSet& sourceFieldSet, FieldSet& tar } } -void ParallelTransport::do_execute(const Field& sourceField, Field& targetField, Metadata &) const -{ +void ParallelTransport::do_execute(const Field& sourceField, Field& targetField, + Metadata&) const { ATLAS_TRACE("atlas::interpolation::method::ParallelTransport::do_execute()"); if (!(sourceField.variables() == 2 || sourceField.variables() == 3)) { @@ -148,7 +181,6 @@ void ParallelTransport::do_execute(const Field& sourceField, Field& targetField, Method::do_execute(sourceField, targetField, metadata); return; - } if (target_.size() == 0) { @@ -159,75 +191,54 @@ void ParallelTransport::do_execute(const Field& sourceField, Field& targetField, if (sourceField.datatype().kind() == array::DataType::KIND_REAL64) { interpolate_vector_field(sourceField, targetField); - } - else if (sourceField.datatype().kind() == array::DataType::KIND_REAL32) { + } else if (sourceField.datatype().kind() == array::DataType::KIND_REAL32) { interpolate_vector_field(sourceField, targetField); - } - else { + } else { ATLAS_NOTIMPLEMENTED; } targetField.set_dirty(); } -template -void ParallelTransport::interpolate_vector_field(const Field& sourceField, Field& targetField) const -{ +template +void ParallelTransport::interpolate_vector_field(const Field& sourceField, + Field& targetField) const { if (sourceField.rank() == 2) { - interpolate_vector_field(sourceField, targetField); - } - else if (sourceField.rank() == 3 ) { + interpolate_vector_field(sourceField, targetField); + } else if (sourceField.rank() == 3) { interpolate_vector_field(sourceField, targetField); - } - else { + } else { ATLAS_NOTIMPLEMENTED; } } -template -void ParallelTransport::interpolate_vector_field(const Field& sourceField, Field& targetField) const -{ - using namespace linalg; +template +void ParallelTransport::interpolate_vector_field(const Field& sourceField, + Field& targetField) const { const auto sourceView = array::make_view(sourceField); auto targetView = array::make_view(targetField); - - array::make_view(targetField).assign(0); - - - - const auto matrixMultiply = [&](const auto& matrix, auto sourceVariableIdx, auto targetVariableIdx) { - - spaceMatrixForEach(matrix, [&](const auto& weight, auto i, auto j){ - - const auto adder = [&](auto&& targetElem, auto&& sourceElem){ - targetElem += weight * sourceElem; - }; - - if constexpr (Rank == 2) { - adder(targetView(i, targetVariableIdx), sourceView(j, sourceVariableIdx)); - } - else if constexpr (Rank == 3) { - const auto sourceSlice = sourceView.slice(j, array::Range::all(), sourceVariableIdx); - auto targetSlice = targetView.slice(i, array::Range::all(), targetVariableIdx); - array::helpers::ArrayForEach<0>::apply(std::tie(targetSlice, sourceSlice), adder); - } - else { - ATLAS_NOTIMPLEMENTED; - } - - }); - }; - - matrixMultiply(matrix00_, 0, 0); - matrixMultiply(matrix01_, 1, 0); - matrixMultiply(matrix10_, 0, 1); - matrixMultiply(matrix11_, 1, 1); + targetView.assign(0.); + + // Matrix multiplication split in two to simulate complex variable + // multiplication. + matrixMultiply(matrixReal_, sourceView, targetView, + [](const auto& weight, auto&& sourceVars, auto&& targetVars) { + targetVars(0) += weight * sourceVars(0); + targetVars(1) += weight * sourceVars(1); + }); + matrixMultiply(matrixImag_, sourceView, targetView, + [](const auto& weight, auto&& sourceVars, auto&& targetVars) { + targetVars(0) -= weight * sourceVars(1); + targetVars(1) += weight * sourceVars(0); + }); if (sourceField.variables() == 3) { - matrixMultiply(matrix(), 2, 2); + matrixMultiply( + matrix(), sourceView, targetView, + [](const auto& weight, auto&& sourceVars, + auto&& targetVars) { targetVars(2) = weight * sourceVars(2); }); } - } } // namespace method diff --git a/src/atlas/interpolation/method/paralleltransport/ParallelTransport.h b/src/atlas/interpolation/method/paralleltransport/ParallelTransport.h index 7384077c8..fbc427104 100644 --- a/src/atlas/interpolation/method/paralleltransport/ParallelTransport.h +++ b/src/atlas/interpolation/method/paralleltransport/ParallelTransport.h @@ -50,11 +50,10 @@ class ParallelTransport : public Method { FunctionSpace source_; FunctionSpace target_; - // Matrices for vector field interpolation - Matrix matrix00_; - Matrix matrix01_; - Matrix matrix10_; - Matrix matrix11_; + // Complex interpolation weights. We treat a (u, v) pair as a complex variable + // with u on the real number line and v on the imaginary number line. + Matrix matrixReal_; + Matrix matrixImag_; }; diff --git a/src/tests/interpolation/test_interpolation_parallel_transport.cc b/src/tests/interpolation/test_interpolation_parallel_transport.cc index 98b87dcea..23967c6f8 100644 --- a/src/tests/interpolation/test_interpolation_parallel_transport.cc +++ b/src/tests/interpolation/test_interpolation_parallel_transport.cc @@ -177,7 +177,7 @@ CASE("finite element vector interpolation") { .set("target_grid", "CS-LFR-48") .set("target_mesh", "cubedsphere_dual") .set("file_id", "parallel_transport_fe") - .set("scheme", baseInterpScheme); + .set("scheme", interpScheme); testInterpolation((cubedSphereConf)); }