From c634ec63e7fdfd8c84140df25d71ec246e8ab667 Mon Sep 17 00:00:00 2001 From: Liam Adams Date: Wed, 20 Nov 2024 15:21:53 +0000 Subject: [PATCH] Use atlas::linalg::SparseMatrix class instead of eckit::linalg::SparseMatrix. --- src/atlas/interpolation/Cache.h | 4 ++-- src/atlas/interpolation/method/Method.h | 4 ++-- src/atlas/interpolation/method/binning/Binning.cc | 4 ++-- .../interpolation/method/knn/GridBoxMaximum.cc | 2 +- .../ConservativeSphericalPolygonInterpolation.cc | 4 ++-- src/atlas/interpolation/nonlinear/Missing.cc | 15 ++++++++++++--- src/atlas/interpolation/nonlinear/NonLinear.h | 8 ++++---- src/atlas/linalg/sparse/SparseMatrixMultiply.h | 3 +-- .../sparse/SparseMatrixMultiply_EckitLinalg.cc | 4 ++-- .../test_interpolation_finite_element_cached.cc | 4 ++-- 10 files changed, 30 insertions(+), 22 deletions(-) diff --git a/src/atlas/interpolation/Cache.h b/src/atlas/interpolation/Cache.h index 2e70f827c..8c2b9d9bf 100644 --- a/src/atlas/interpolation/Cache.h +++ b/src/atlas/interpolation/Cache.h @@ -17,8 +17,8 @@ #include "eckit/filesystem/PathName.h" #include "eckit/io/Buffer.h" -#include "eckit/linalg/SparseMatrix.h" +#include "atlas/linalg/SparseMatrix.h" #include "atlas/runtime/Exception.h" #include "atlas/util/KDTree.h" @@ -82,7 +82,7 @@ class Cache { class MatrixCacheEntry : public InterpolationCacheEntry { public: - using Matrix = eckit::linalg::SparseMatrix; + using Matrix = atlas::linalg::SparseMatrix; ~MatrixCacheEntry() override; MatrixCacheEntry(const Matrix* matrix, const std::string& uid = ""): matrix_{matrix}, uid_(uid) { ATLAS_ASSERT(matrix_ != nullptr); diff --git a/src/atlas/interpolation/method/Method.h b/src/atlas/interpolation/method/Method.h index 7e52d5bf3..5b37ca25c 100644 --- a/src/atlas/interpolation/method/Method.h +++ b/src/atlas/interpolation/method/Method.h @@ -16,10 +16,10 @@ #include "atlas/interpolation/Cache.h" #include "atlas/interpolation/NonLinear.h" +#include "atlas/linalg/SparseMatrix.h" #include "atlas/util/Metadata.h" #include "atlas/util/Object.h" #include "eckit/config/Configuration.h" -#include "eckit/linalg/SparseMatrix.h" namespace atlas { class Field; @@ -87,7 +87,7 @@ class Method : public util::Object { using Triplet = eckit::linalg::Triplet; using Triplets = std::vector; - using Matrix = eckit::linalg::SparseMatrix; + using Matrix = atlas::linalg::SparseMatrix; static void normalise(Triplets& triplets); diff --git a/src/atlas/interpolation/method/binning/Binning.cc b/src/atlas/interpolation/method/binning/Binning.cc index f5a6b0743..902072f4f 100644 --- a/src/atlas/interpolation/method/binning/Binning.cc +++ b/src/atlas/interpolation/method/binning/Binning.cc @@ -12,12 +12,12 @@ #include "atlas/interpolation/Interpolation.h" #include "atlas/interpolation/method/binning/Binning.h" #include "atlas/interpolation/method/MethodFactory.h" +#include "atlas/linalg/SparseMatrix.h" #include "atlas/mesh.h" #include "atlas/mesh/actions/GetCubedSphereNodalArea.h" #include "atlas/runtime/Trace.h" #include "eckit/config/LocalConfiguration.h" -#include "eckit/linalg/SparseMatrix.h" #include "eckit/linalg/Triplet.h" #include "eckit/mpi/Comm.h" @@ -58,7 +58,7 @@ void Binning::do_setup(const FunctionSpace& source, using Index = eckit::linalg::Index; using Triplet = eckit::linalg::Triplet; - using SMatrix = eckit::linalg::SparseMatrix; + using SMatrix = atlas::linalg::SparseMatrix; source_ = source; target_ = target; diff --git a/src/atlas/interpolation/method/knn/GridBoxMaximum.cc b/src/atlas/interpolation/method/knn/GridBoxMaximum.cc index 8ca86672a..6a3435478 100644 --- a/src/atlas/interpolation/method/knn/GridBoxMaximum.cc +++ b/src/atlas/interpolation/method/knn/GridBoxMaximum.cc @@ -61,7 +61,7 @@ void GridBoxMaximum::do_execute(const Field& source, Field& target, Metadata&) c if (!matrixFree_) { const Matrix& m = matrix(); - Matrix::const_iterator k(m); + Matrix::const_iterator k = m.begin(); for (decltype(m.rows()) i = 0, j = 0; i < m.rows(); ++i) { double max = std::numeric_limits::lowest(); diff --git a/src/atlas/interpolation/method/unstructured/ConservativeSphericalPolygonInterpolation.cc b/src/atlas/interpolation/method/unstructured/ConservativeSphericalPolygonInterpolation.cc index d1ea5addf..334b65422 100644 --- a/src/atlas/interpolation/method/unstructured/ConservativeSphericalPolygonInterpolation.cc +++ b/src/atlas/interpolation/method/unstructured/ConservativeSphericalPolygonInterpolation.cc @@ -1248,7 +1248,7 @@ void ConservativeSphericalPolygonInterpolation::intersect_polygons(const CSPolyg } } -eckit::linalg::SparseMatrix ConservativeSphericalPolygonInterpolation::compute_1st_order_matrix() { +atlas::linalg::SparseMatrix ConservativeSphericalPolygonInterpolation::compute_1st_order_matrix() { ATLAS_TRACE("ConservativeMethod::setup: build cons-1 interpolant matrix"); ATLAS_ASSERT(not matrix_free_); Triplets triplets; @@ -1358,7 +1358,7 @@ eckit::linalg::SparseMatrix ConservativeSphericalPolygonInterpolation::compute_1 return Matrix(n_tpoints_, n_spoints_, triplets); } -eckit::linalg::SparseMatrix ConservativeSphericalPolygonInterpolation::compute_2nd_order_matrix() { +atlas::linalg::SparseMatrix ConservativeSphericalPolygonInterpolation::compute_2nd_order_matrix() { ATLAS_TRACE("ConservativeMethod::setup: build cons-2 interpolant matrix"); ATLAS_ASSERT(not matrix_free_); const auto& src_points_ = data_->src_points_; diff --git a/src/atlas/interpolation/nonlinear/Missing.cc b/src/atlas/interpolation/nonlinear/Missing.cc index 1258f9fdc..47c35f6b6 100644 --- a/src/atlas/interpolation/nonlinear/Missing.cc +++ b/src/atlas/interpolation/nonlinear/Missing.cc @@ -81,7 +81,7 @@ bool MissingIfAllMissing::executeT(NonLinear::Matrix& W, const Field& field) con bool zeros = false; Size i = 0; - Matrix::iterator it(W); + Matrix::iterator it = W.begin(); for (Size r = 0; r < W.rows(); ++r) { const Matrix::iterator end = W.end(r); @@ -128,6 +128,9 @@ bool MissingIfAllMissing::executeT(NonLinear::Matrix& W, const Field& field) con modif = true; } } + if (modif) { + W.setDeviceNeedsUpdate(true); + } if (zeros && missingValue.isnan()) { W.prune(0.); @@ -161,7 +164,7 @@ bool MissingIfAnyMissing::executeT(NonLinear::Matrix& W, const Field& field) con bool zeros = false; Size i = 0; - Matrix::iterator it(W); + Matrix::iterator it = W.begin(); for (Size r = 0; r < W.rows(); ++r) { const Matrix::iterator end = W.end(r); @@ -195,6 +198,9 @@ bool MissingIfAnyMissing::executeT(NonLinear::Matrix& W, const Field& field) con modif = true; } } + if (modif) { + W.setDeviceNeedsUpdate(true); + } if (zeros && missingValue.isnan()) { W.prune(0.); @@ -228,7 +234,7 @@ bool MissingIfHeaviestMissing::executeT(NonLinear::Matrix& W, const Field& field bool zeros = false; Size i = 0; - Matrix::iterator it(W); + Matrix::iterator it = W.begin(); for (Size r = 0; r < W.rows(); ++r) { const Matrix::iterator end = W.end(r); @@ -282,6 +288,9 @@ bool MissingIfHeaviestMissing::executeT(NonLinear::Matrix& W, const Field& field modif = true; } } + if (modif) { + W.setDeviceNeedsUpdate(true); + } if (zeros && missingValue.isnan()) { W.prune(0.); diff --git a/src/atlas/interpolation/nonlinear/NonLinear.h b/src/atlas/interpolation/nonlinear/NonLinear.h index dce24441e..7ce865434 100644 --- a/src/atlas/interpolation/nonlinear/NonLinear.h +++ b/src/atlas/interpolation/nonlinear/NonLinear.h @@ -16,10 +16,10 @@ #include #include "eckit/config/Parametrisation.h" -#include "eckit/linalg/SparseMatrix.h" #include "atlas/array.h" #include "atlas/field/Field.h" +#include "atlas/linalg/SparseMatrix.h" #include "atlas/runtime/Exception.h" #include "atlas/util/Factory.h" #include "atlas/util/ObjectHandle.h" @@ -37,9 +37,9 @@ namespace nonlinear { class NonLinear : public util::Object { public: using Config = eckit::Parametrisation; - using Matrix = eckit::linalg::SparseMatrix; - using Scalar = eckit::linalg::Scalar; - using Size = eckit::linalg::Size; + using Matrix = atlas::linalg::SparseMatrix; + using Scalar = atlas::linalg::SparseMatrix::Scalar; + using Size = atlas::linalg::SparseMatrix::Size; /** * @brief ctor diff --git a/src/atlas/linalg/sparse/SparseMatrixMultiply.h b/src/atlas/linalg/sparse/SparseMatrixMultiply.h index d1776975b..6d923ebb0 100644 --- a/src/atlas/linalg/sparse/SparseMatrixMultiply.h +++ b/src/atlas/linalg/sparse/SparseMatrixMultiply.h @@ -11,9 +11,9 @@ #pragma once #include "eckit/config/Configuration.h" -#include "eckit/linalg/SparseMatrix.h" #include "atlas/linalg/Indexing.h" +#include "atlas/linalg/SparseMatrix.h" #include "atlas/linalg/View.h" #include "atlas/linalg/sparse/Backend.h" #include "atlas/runtime/Exception.h" @@ -22,7 +22,6 @@ namespace atlas { namespace linalg { -using SparseMatrix = eckit::linalg::SparseMatrix; using Configuration = eckit::Configuration; template diff --git a/src/atlas/linalg/sparse/SparseMatrixMultiply_EckitLinalg.cc b/src/atlas/linalg/sparse/SparseMatrixMultiply_EckitLinalg.cc index 72d55c1a9..b9bb1f9b4 100644 --- a/src/atlas/linalg/sparse/SparseMatrixMultiply_EckitLinalg.cc +++ b/src/atlas/linalg/sparse/SparseMatrixMultiply_EckitLinalg.cc @@ -70,7 +70,7 @@ void SparseMatrixMultiply::apply( @@ -81,7 +81,7 @@ void SparseMatrixMultiply= W.rows()); eckit::linalg::Matrix m_src(src.data(), src.shape(1), src.shape(0)); eckit::linalg::Matrix m_tgt(tgt.data(), tgt.shape(1), tgt.shape(0)); - eckit_linalg_backend(config).spmm(W, m_src, m_tgt); + eckit_linalg_backend(config).spmm(W.host_matrix(), m_src, m_tgt); } void SparseMatrixMultiply::apply( diff --git a/src/tests/interpolation/test_interpolation_finite_element_cached.cc b/src/tests/interpolation/test_interpolation_finite_element_cached.cc index 99d9803c6..971ffd90d 100644 --- a/src/tests/interpolation/test_interpolation_finite_element_cached.cc +++ b/src/tests/interpolation/test_interpolation_finite_element_cached.cc @@ -106,7 +106,7 @@ CASE("extract cache, copy it, and move it for use") { set_field(field_source, grid_source, func); - eckit::linalg::SparseMatrix matrix = get_or_create_cache(grid_source, grid_target).matrix(); + atlas::linalg::SparseMatrix matrix = get_or_create_cache(grid_source, grid_target).matrix(); EXPECT(not matrix.empty()); @@ -133,7 +133,7 @@ CASE("extract cache, copy it, and pass non-owning pointer") { set_field(field_source, grid_source, func); - eckit::linalg::SparseMatrix matrix = get_or_create_cache(grid_source, grid_target).matrix(); + atlas::linalg::SparseMatrix matrix = get_or_create_cache(grid_source, grid_target).matrix(); EXPECT(not matrix.empty());