Skip to content

Commit

Permalink
Replaced rotation matrices with complex interpolation weights.
Browse files Browse the repository at this point in the history
  • Loading branch information
odlomax committed Nov 21, 2023
1 parent 7c0a5d7 commit d01f5dd
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 99 deletions.
197 changes: 104 additions & 93 deletions src/atlas/interpolation/method/paralleltransport/ParallelTransport.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,14 @@

#include "eckit/linalg/Triplet.h"


namespace atlas {
namespace interpolation {
namespace method {

namespace {
MethodBuilder<ParallelTransport> __builder("parallel-transport");

template<typename MatrixT, typename Functor>
template <typename MatrixT, typename Functor>
void spaceMatrixForEach(MatrixT&& matrix, const Functor& functor) {

const auto nRows = matrix.rows();
Expand All @@ -44,31 +43,69 @@ void spaceMatrixForEach(MatrixT&& matrix, const Functor& functor) {
const auto colIndices = matrix.inner();
auto valData = matrix.data();

atlas_omp_parallel_for (auto i = size_t{}; i < nRows; ++i) {
for (auto dataIdx = rowIndices[i]; dataIdx < rowIndices[i+1]; ++dataIdx) {
atlas_omp_parallel_for(auto i = size_t{}; i < nRows; ++i) {
for (auto dataIdx = rowIndices[i]; dataIdx < rowIndices[i + 1]; ++dataIdx) {
const auto j = size_t(colIndices[dataIdx]);
auto& value = valData[dataIdx];

if constexpr (std::is_invocable_v<Functor, decltype(value), size_t, size_t>) {
functor(value, i, j);
}
else if constexpr (std::is_invocable_v<Functor, decltype(value), size_t, size_t, size_t>) {
functor(value, i, j, dataIdx);
}
auto&& value = valData[dataIdx];

if
constexpr(
std::is_invocable_v<Functor, decltype(value), size_t, size_t>) {
functor(value, i, j);
}
else if
constexpr(std::is_invocable_v<Functor, decltype(value), size_t, size_t,
size_t>) {
functor(value, i, j, dataIdx);
}
else {
ATLAS_NOTIMPLEMENTED;
}
}
}
}

template <typename MatrixT, typename SourceView, typename TargetView,
typename Functor>
void matrixMultiply(const MatrixT& matrix, SourceView&& sourceView,
TargetView&& targetView, const Functor& mappingFunctor) {

spaceMatrixForEach(matrix, [&](const auto& weight, auto i, auto j) {

constexpr auto rank = std::decay_t<decltype(sourceView)>::rank();
if
constexpr(rank == 2) {
const auto sourceSlice = sourceView.slice(j, array::Range::all());
auto targetSlice = targetView.slice(i, array::Range::all());
mappingFunctor(weight, sourceSlice, targetSlice);
}
else if
constexpr(rank == 3) {
const auto iterationFuctor = [&](auto&& sourceVars, auto&& targetVars) {
mappingFunctor(weight, sourceVars, targetVars);
};
const auto sourceSlice =
sourceView.slice(j, array::Range::all(), array::Range::all());
auto targetSlice =
targetView.slice(i, array::Range::all(), array::Range::all());
array::helpers::ArrayForEach<0>::apply(
std::tie(sourceSlice, targetSlice), iterationFuctor);
}
else {
ATLAS_NOTIMPLEMENTED;
}
});
}

} // namespace

void ParallelTransport::do_setup(const Grid& source, const Grid& target, const Cache&) {
void ParallelTransport::do_setup(const Grid& source, const Grid& target,
const Cache&) {
ATLAS_NOTIMPLEMENTED;
}

void ParallelTransport::do_setup(const FunctionSpace& source, const FunctionSpace& target) {
void ParallelTransport::do_setup(const FunctionSpace& source,
const FunctionSpace& target) {
ATLAS_TRACE("interpolation::method::ParallelTransport::do_setup");
source_ = source;
target_ = target;
Expand All @@ -77,57 +114,53 @@ void ParallelTransport::do_setup(const FunctionSpace& source, const FunctionSpac
return;
}

const auto baseInterpolator = Interpolation(interpolationScheme_, source_, target_);
const auto baseInterpolator =
Interpolation(interpolationScheme_, source_, target_);
setMatrix(MatrixCache(baseInterpolator));

// Get matrix dimensions.
const auto nRows = matrix().rows();
const auto nCols = matrix().cols();
const auto nNonZeros = matrix().nonZeros();

auto weights00 = std::vector<eckit::linalg::Triplet>(nNonZeros);
auto weights01 = std::vector<eckit::linalg::Triplet>(nNonZeros);
auto weights10 = std::vector<eckit::linalg::Triplet>(nNonZeros);
auto weights11 = std::vector<eckit::linalg::Triplet>(nNonZeros);
auto weightsReal = std::vector<eckit::linalg::Triplet>(nNonZeros);
auto weightsImag = std::vector<eckit::linalg::Triplet>(nNonZeros);

const auto sourceLonLats = array::make_view<double, 2>(source_.lonlat());
const auto targetLonLats = array::make_view<double, 2>(target_.lonlat());


spaceMatrixForEach(matrix(), [&](auto&& weight, auto i, auto j, auto dataIdx){
const auto sourceLonLat = PointLonLat(sourceLonLats(j, 0), sourceLonLats(j, 1));
const auto targetLonLat = PointLonLat(targetLonLats(i, 0), targetLonLats(i, 1));
// Make complex weights (would be nice if we could have a complex matrix).
spaceMatrixForEach(matrix(),
[&](auto&& weight, auto i, auto j, auto dataIdx) {
const auto sourceLonLat =
PointLonLat(sourceLonLats(j, 0), sourceLonLats(j, 1));
const auto targetLonLat =
PointLonLat(targetLonLats(i, 0), targetLonLats(i, 1));

const auto alpha = util::greatCircleCourse(sourceLonLat, targetLonLat);

auto deltaAlpha = (alpha.first - alpha.second) * util::Constants::degreesToRadians();
auto deltaAlpha =
(alpha.first - alpha.second) * util::Constants::degreesToRadians();

weights00[dataIdx] = {i, j, weight * std::cos(deltaAlpha)};
weights01[dataIdx] = {i, j, -weight * std::sin(deltaAlpha)};
weights10[dataIdx] = {i, j, weight * std::sin(deltaAlpha)};
weights11[dataIdx] = {i, j, weight * std::cos(deltaAlpha)};
weightsReal[dataIdx] = {i, j, weight * std::cos(deltaAlpha)};
weightsImag[dataIdx] = {i, j, weight * std::sin(deltaAlpha)};
});


// Deal with slightly old fashioned Matrix interface
const auto buildMatrix = [&](auto& matrix, const auto& weights){
const auto buildMatrix = [&](auto& matrix, const auto& weights) {
auto tempMatrix = Matrix(nRows, nCols, weights);
matrix.swap(tempMatrix);
};

buildMatrix(matrix00_, weights00);
buildMatrix(matrix10_, weights10);
buildMatrix(matrix01_, weights01);
buildMatrix(matrix11_, weights11);

buildMatrix(matrixReal_, weightsReal);
buildMatrix(matrixImag_, weightsImag);
}

void ParallelTransport::print(std::ostream&) const {
ATLAS_NOTIMPLEMENTED;
}
void ParallelTransport::print(std::ostream&) const { ATLAS_NOTIMPLEMENTED; }

void ParallelTransport::do_execute(const FieldSet& sourceFieldSet, FieldSet& targetFieldSet, Metadata& metadata) const
{
void ParallelTransport::do_execute(const FieldSet& sourceFieldSet,
FieldSet& targetFieldSet,
Metadata& metadata) const {
ATLAS_TRACE("atlas::interpolation::method::ParallelTransport::do_execute()");

const auto nFields = sourceFieldSet.size();
Expand All @@ -138,8 +171,8 @@ void ParallelTransport::do_execute(const FieldSet& sourceFieldSet, FieldSet& tar
}
}

void ParallelTransport::do_execute(const Field& sourceField, Field& targetField, Metadata &) const
{
void ParallelTransport::do_execute(const Field& sourceField, Field& targetField,
Metadata&) const {
ATLAS_TRACE("atlas::interpolation::method::ParallelTransport::do_execute()");

if (!(sourceField.variables() == 2 || sourceField.variables() == 3)) {
Expand All @@ -148,7 +181,6 @@ void ParallelTransport::do_execute(const Field& sourceField, Field& targetField,
Method::do_execute(sourceField, targetField, metadata);

return;

}

if (target_.size() == 0) {
Expand All @@ -159,75 +191,54 @@ void ParallelTransport::do_execute(const Field& sourceField, Field& targetField,

if (sourceField.datatype().kind() == array::DataType::KIND_REAL64) {
interpolate_vector_field<double>(sourceField, targetField);
}
else if (sourceField.datatype().kind() == array::DataType::KIND_REAL32) {
} else if (sourceField.datatype().kind() == array::DataType::KIND_REAL32) {
interpolate_vector_field<float>(sourceField, targetField);
}
else {
} else {
ATLAS_NOTIMPLEMENTED;
}

targetField.set_dirty();
}

template<typename Value>
void ParallelTransport::interpolate_vector_field(const Field& sourceField, Field& targetField) const
{
template <typename Value>
void ParallelTransport::interpolate_vector_field(const Field& sourceField,
Field& targetField) const {
if (sourceField.rank() == 2) {
interpolate_vector_field<Value, 2>(sourceField, targetField);
}
else if (sourceField.rank() == 3 ) {
interpolate_vector_field<Value, 2>(sourceField, targetField);
} else if (sourceField.rank() == 3) {
interpolate_vector_field<Value, 3>(sourceField, targetField);
}
else {
} else {
ATLAS_NOTIMPLEMENTED;
}
}

template<typename Value, int Rank>
void ParallelTransport::interpolate_vector_field(const Field& sourceField, Field& targetField) const
{
using namespace linalg;
template <typename Value, int Rank>
void ParallelTransport::interpolate_vector_field(const Field& sourceField,
Field& targetField) const {

const auto sourceView = array::make_view<Value, Rank>(sourceField);
auto targetView = array::make_view<Value, Rank>(targetField);

array::make_view<Value, Rank>(targetField).assign(0);



const auto matrixMultiply = [&](const auto& matrix, auto sourceVariableIdx, auto targetVariableIdx) {

spaceMatrixForEach(matrix, [&](const auto& weight, auto i, auto j){

const auto adder = [&](auto&& targetElem, auto&& sourceElem){
targetElem += weight * sourceElem;
};

if constexpr (Rank == 2) {
adder(targetView(i, targetVariableIdx), sourceView(j, sourceVariableIdx));
}
else if constexpr (Rank == 3) {
const auto sourceSlice = sourceView.slice(j, array::Range::all(), sourceVariableIdx);
auto targetSlice = targetView.slice(i, array::Range::all(), targetVariableIdx);
array::helpers::ArrayForEach<0>::apply(std::tie(targetSlice, sourceSlice), adder);
}
else {
ATLAS_NOTIMPLEMENTED;
}

});
};

matrixMultiply(matrix00_, 0, 0);
matrixMultiply(matrix01_, 1, 0);
matrixMultiply(matrix10_, 0, 1);
matrixMultiply(matrix11_, 1, 1);
targetView.assign(0.);

// Matrix multiplication split in two to simulate complex variable
// multiplication.
matrixMultiply(matrixReal_, sourceView, targetView,
[](const auto& weight, auto&& sourceVars, auto&& targetVars) {
targetVars(0) += weight * sourceVars(0);
targetVars(1) += weight * sourceVars(1);
});
matrixMultiply(matrixImag_, sourceView, targetView,
[](const auto& weight, auto&& sourceVars, auto&& targetVars) {
targetVars(0) -= weight * sourceVars(1);
targetVars(1) += weight * sourceVars(0);
});

if (sourceField.variables() == 3) {
matrixMultiply(matrix(), 2, 2);
matrixMultiply(
matrix(), sourceView, targetView,
[](const auto& weight, auto&& sourceVars,
auto&& targetVars) { targetVars(2) = weight * sourceVars(2); });
}

}

} // namespace method
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,10 @@ class ParallelTransport : public Method {
FunctionSpace source_;
FunctionSpace target_;

// Matrices for vector field interpolation
Matrix matrix00_;
Matrix matrix01_;
Matrix matrix10_;
Matrix matrix11_;
// Complex interpolation weights. We treat a (u, v) pair as a complex variable
// with u on the real number line and v on the imaginary number line.
Matrix matrixReal_;
Matrix matrixImag_;

};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ CASE("finite element vector interpolation") {
.set("target_grid", "CS-LFR-48")
.set("target_mesh", "cubedsphere_dual")
.set("file_id", "parallel_transport_fe")
.set("scheme", baseInterpScheme);
.set("scheme", interpScheme);

testInterpolation((cubedSphereConf));
}
Expand Down

0 comments on commit d01f5dd

Please sign in to comment.