Skip to content

Commit

Permalink
Forced consistency between integral types.
Browse files Browse the repository at this point in the history
  • Loading branch information
odlomax committed Feb 9, 2024
1 parent 231fc16 commit b376c96
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class ComplexMatrixMultiply {
ATLAS_ASSERT(complexWeightsPtr_->cols() == realWeightsPtr_->cols());
ATLAS_ASSERT(complexWeightsPtr_->nonZeros() ==
realWeightsPtr_->nonZeros());
for (auto rowIndex = Size{0}; rowIndex < complexWeightsPtr_->rows();
for (auto rowIndex = Index{0}; rowIndex < complexWeightsPtr_->rows();
++rowIndex) {
for (auto [complexRowIter, realRowIter] = rowIters(rowIndex);
complexRowIter; ++complexRowIter, ++realRowIter) {
Expand Down Expand Up @@ -100,7 +100,7 @@ class ComplexMatrixMultiply {
// We could probably optimise contiguous arrays using
// reinterpret_cast<std::complex<double>*>(view.data()). This is fine
// according to the C++ standard!
atlas_omp_parallel_for(auto rowIndex = Size{0};
atlas_omp_parallel_for(auto rowIndex = Index{0};
rowIndex < complexWeightsPtr_->rows(); ++rowIndex) {
auto targetSlice = sliceColumn(targetView, rowIndex);
if constexpr (InitialiseTarget) {
Expand Down Expand Up @@ -129,7 +129,7 @@ class ComplexMatrixMultiply {
template <typename Value, int Rank>
void applyThreeVector(const array::ArrayView<const Value, Rank>& sourceView,
array::ArrayView<Value, Rank>& targetView) const {
atlas_omp_parallel_for(auto rowIndex = Size{0};
atlas_omp_parallel_for(auto rowIndex = Index{0};
rowIndex < complexWeightsPtr_->rows(); ++rowIndex) {
auto targetSlice = sliceColumn(targetView, rowIndex);
if constexpr (InitialiseTarget) {
Expand Down Expand Up @@ -158,7 +158,7 @@ class ComplexMatrixMultiply {

/// @brief Return a pair of complex and real row iterators
std::pair<ComplexMatrix::RowIter, RealMatrix::RowIter> rowIters(
Size rowIndex) const {
Index rowIndex) const {
return std::make_pair(complexWeightsPtr_->rowIter(rowIndex),
realWeightsPtr_->rowIter(rowIndex));
}
Expand Down
26 changes: 12 additions & 14 deletions src/atlas/interpolation/method/sphericalvector/SparseMatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#endif

#include "atlas/runtime/Exception.h"
#include "eckit/log/CodeLocation.h"

namespace atlas {
namespace interpolation {
Expand All @@ -31,12 +32,10 @@ namespace detail {
/// class is Eigen library is not present.
template <typename Value>
class SparseMatrix {
using EigenMatrix = Eigen::SparseMatrix<Value, Eigen::RowMajor>;

public:
using Index = typename EigenMatrix::StorageIndex;
using Size = typename EigenMatrix::Index;
using Triplet = Eigen::Triplet<Value>;
using Index = int;
using EigenMatrix = Eigen::SparseMatrix<Value, Eigen::RowMajor, Index>;
using Triplet = Eigen::Triplet<Value, Index>;
using Triplets = std::vector<Triplet>;
using RowIter = typename EigenMatrix::InnerIterator;

Expand All @@ -45,10 +44,10 @@ class SparseMatrix {
eigenMatrix_.setFromTriplets(triplets.begin(), triplets.end());
}

Size nonZeros() const { return eigenMatrix_.nonZeros(); }
Size rows() const { return eigenMatrix_.rows(); }
Size cols() const { return eigenMatrix_.cols(); }
RowIter rowIter(Size rowIndex) const {
Index nonZeros() const { return eigenMatrix_.nonZeros(); }
Index rows() const { return eigenMatrix_.rows(); }
Index cols() const { return eigenMatrix_.cols(); }
RowIter rowIter(Index rowIndex) const {
return RowIter(eigenMatrix_, rowIndex);
}
SparseMatrix<Value> adjoint() const {
Expand All @@ -66,7 +65,6 @@ template <typename Value>
class SparseMatrix {
public:
using Index = int;
using Size = long int;

class Triplet {
public:
Expand All @@ -90,10 +88,10 @@ class SparseMatrix {
SparseMatrix(const Args&... args) {
throw_Exception("Atlas has been compiled without Eigen", Here());
}
constexpr Size nonZeros() const { return Size{}; }
constexpr Size rows() const { return Size{}; }
constexpr Size cols() const { return Size{}; }
constexpr RowIter rowIter(Size rowIndex) const { return RowIter{}; }
constexpr Index nonZeros() const { return Index{}; }
constexpr Index rows() const { return Index{}; }
constexpr Index cols() const { return Index{}; }
constexpr RowIter rowIter(Index rowIndex) const { return RowIter{}; }
constexpr SparseMatrix<Value> adjoint() const {
return SparseMatrix<Value>{};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ using namespace detail;

SphericalVector::SphericalVector(const Config& config) : Method(config) {
const auto& conf = dynamic_cast<const eckit::LocalConfiguration&>(config);
ATLAS_ASSERT_MSG(&conf,
"config must be castable to an eckit::LocalConfiguration");
interpolationScheme_ = conf.getSubConfiguration("scheme");
adjoint_ = conf.getBool("adjoint", false);
}
Expand All @@ -56,9 +58,9 @@ void SphericalVector::do_setup(const FunctionSpace& source,
setMatrix(Interpolation(interpolationScheme_, source_, target_));

// Get matrix data.
const auto nRows = matrix().rows();
const auto nCols = matrix().cols();
const auto nNonZeros = matrix().nonZeros();
const auto nRows = static_cast<Index>(matrix().rows());
const auto nCols = static_cast<Index>(matrix().cols());
const auto nNonZeros = static_cast<std::size_t>(matrix().nonZeros());
const auto* outerIndices = matrix().outer();
const auto* innerIndices = matrix().inner();
const auto* baseWeights = matrix().data();
Expand Down Expand Up @@ -87,7 +89,7 @@ void SphericalVector::do_setup(const FunctionSpace& source,
++rowIndex) {
for (auto dataIndex = outerIndices[rowIndex];
dataIndex < outerIndices[rowIndex + 1]; ++dataIndex) {
const auto colIndex = innerIndices[dataIndex];
const auto colIndex = static_cast<Index>(innerIndices[dataIndex]);
const auto baseWeight = baseWeights[dataIndex];

const auto sourceLonLat = PointLonLat(sourceLonLatsView(colIndex, 0),
Expand Down
1 change: 0 additions & 1 deletion src/atlas/interpolation/method/sphericalvector/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ using Complex = std::complex<Real>;
using ComplexMatrix = SparseMatrix<Complex>;
using RealMatrix = SparseMatrix<Real>;
using Index = ComplexMatrix::Index;
using Size = ComplexMatrix::Size;
using ComplexTriplet = ComplexMatrix::Triplet;
using ComplexTriplets = ComplexMatrix::Triplets;
using RealTriplet = RealMatrix::Triplet;
Expand Down

0 comments on commit b376c96

Please sign in to comment.