Skip to content

Commit

Permalink
Use atlas::linalg::SparseMatrix class instead of eckit::linalg::Spars…
Browse files Browse the repository at this point in the history
…eMatrix.
  • Loading branch information
l90lpa committed Nov 20, 2024
1 parent b5528ee commit 24321a5
Show file tree
Hide file tree
Showing 10 changed files with 30 additions and 22 deletions.
4 changes: 2 additions & 2 deletions src/atlas/interpolation/Cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
#include "eckit/filesystem/PathName.h"
#include "eckit/io/Buffer.h"

#include "eckit/linalg/SparseMatrix.h"

#include "atlas/linalg/SparseMatrix.h"
#include "atlas/runtime/Exception.h"
#include "atlas/util/KDTree.h"

Expand Down Expand Up @@ -82,7 +82,7 @@ class Cache {

class MatrixCacheEntry : public InterpolationCacheEntry {
public:
using Matrix = eckit::linalg::SparseMatrix;
using Matrix = atlas::linalg::SparseMatrix;
~MatrixCacheEntry() override;
MatrixCacheEntry(const Matrix* matrix, const std::string& uid = ""): matrix_{matrix}, uid_(uid) {
ATLAS_ASSERT(matrix_ != nullptr);
Expand Down
4 changes: 2 additions & 2 deletions src/atlas/interpolation/method/Method.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@

#include "atlas/interpolation/Cache.h"
#include "atlas/interpolation/NonLinear.h"
#include "atlas/linalg/SparseMatrix.h"
#include "atlas/util/Metadata.h"
#include "atlas/util/Object.h"
#include "eckit/config/Configuration.h"
#include "eckit/linalg/SparseMatrix.h"

namespace atlas {
class Field;
Expand Down Expand Up @@ -87,7 +87,7 @@ class Method : public util::Object {

using Triplet = eckit::linalg::Triplet;
using Triplets = std::vector<Triplet>;
using Matrix = eckit::linalg::SparseMatrix;
using Matrix = atlas::linalg::SparseMatrix;

static void normalise(Triplets& triplets);

Expand Down
4 changes: 2 additions & 2 deletions src/atlas/interpolation/method/binning/Binning.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
#include "atlas/interpolation/Interpolation.h"
#include "atlas/interpolation/method/binning/Binning.h"
#include "atlas/interpolation/method/MethodFactory.h"
#include "atlas/linalg/SparseMatrix.h"
#include "atlas/mesh.h"
#include "atlas/mesh/actions/GetCubedSphereNodalArea.h"
#include "atlas/runtime/Trace.h"

#include "eckit/config/LocalConfiguration.h"
#include "eckit/linalg/SparseMatrix.h"
#include "eckit/linalg/Triplet.h"
#include "eckit/mpi/Comm.h"

Expand Down Expand Up @@ -58,7 +58,7 @@ void Binning::do_setup(const FunctionSpace& source,

using Index = eckit::linalg::Index;
using Triplet = eckit::linalg::Triplet;
using SMatrix = eckit::linalg::SparseMatrix;
using SMatrix = atlas::linalg::SparseMatrix;

source_ = source;
target_ = target;
Expand Down
2 changes: 1 addition & 1 deletion src/atlas/interpolation/method/knn/GridBoxMaximum.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ void GridBoxMaximum::do_execute(const Field& source, Field& target, Metadata&) c

if (!matrixFree_) {
const Matrix& m = matrix();
Matrix::const_iterator k(m);
Matrix::const_iterator k = m.begin();

for (decltype(m.rows()) i = 0, j = 0; i < m.rows(); ++i) {
double max = std::numeric_limits<double>::lowest();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1248,7 +1248,7 @@ void ConservativeSphericalPolygonInterpolation::intersect_polygons(const CSPolyg
}
}

eckit::linalg::SparseMatrix ConservativeSphericalPolygonInterpolation::compute_1st_order_matrix() {
atlas::linalg::SparseMatrix ConservativeSphericalPolygonInterpolation::compute_1st_order_matrix() {
ATLAS_TRACE("ConservativeMethod::setup: build cons-1 interpolant matrix");
ATLAS_ASSERT(not matrix_free_);
Triplets triplets;
Expand Down Expand Up @@ -1358,7 +1358,7 @@ eckit::linalg::SparseMatrix ConservativeSphericalPolygonInterpolation::compute_1
return Matrix(n_tpoints_, n_spoints_, triplets);
}

eckit::linalg::SparseMatrix ConservativeSphericalPolygonInterpolation::compute_2nd_order_matrix() {
atlas::linalg::SparseMatrix ConservativeSphericalPolygonInterpolation::compute_2nd_order_matrix() {
ATLAS_TRACE("ConservativeMethod::setup: build cons-2 interpolant matrix");
ATLAS_ASSERT(not matrix_free_);
const auto& src_points_ = data_->src_points_;
Expand Down
15 changes: 12 additions & 3 deletions src/atlas/interpolation/nonlinear/Missing.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ bool MissingIfAllMissing::executeT(NonLinear::Matrix& W, const Field& field) con
bool zeros = false;

Size i = 0;
Matrix::iterator it(W);
Matrix::iterator it = W.begin();
for (Size r = 0; r < W.rows(); ++r) {
const Matrix::iterator end = W.end(r);

Expand Down Expand Up @@ -128,6 +128,9 @@ bool MissingIfAllMissing::executeT(NonLinear::Matrix& W, const Field& field) con
modif = true;
}
}
if (modif) {
W.setDeviceNeedsUpdate(true);
}

if (zeros && missingValue.isnan()) {
W.prune(0.);
Expand Down Expand Up @@ -161,7 +164,7 @@ bool MissingIfAnyMissing::executeT(NonLinear::Matrix& W, const Field& field) con
bool zeros = false;

Size i = 0;
Matrix::iterator it(W);
Matrix::iterator it = W.begin();
for (Size r = 0; r < W.rows(); ++r) {
const Matrix::iterator end = W.end(r);

Expand Down Expand Up @@ -195,6 +198,9 @@ bool MissingIfAnyMissing::executeT(NonLinear::Matrix& W, const Field& field) con
modif = true;
}
}
if (modif) {
W.setDeviceNeedsUpdate(true);
}

if (zeros && missingValue.isnan()) {
W.prune(0.);
Expand Down Expand Up @@ -228,7 +234,7 @@ bool MissingIfHeaviestMissing::executeT(NonLinear::Matrix& W, const Field& field
bool zeros = false;

Size i = 0;
Matrix::iterator it(W);
Matrix::iterator it = W.begin();
for (Size r = 0; r < W.rows(); ++r) {
const Matrix::iterator end = W.end(r);

Expand Down Expand Up @@ -282,6 +288,9 @@ bool MissingIfHeaviestMissing::executeT(NonLinear::Matrix& W, const Field& field
modif = true;
}
}
if (modif) {
W.setDeviceNeedsUpdate(true);
}

if (zeros && missingValue.isnan()) {
W.prune(0.);
Expand Down
8 changes: 4 additions & 4 deletions src/atlas/interpolation/nonlinear/NonLinear.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
#include <type_traits>

#include "eckit/config/Parametrisation.h"
#include "eckit/linalg/SparseMatrix.h"

#include "atlas/array.h"
#include "atlas/field/Field.h"
#include "atlas/linalg/SparseMatrix.h"
#include "atlas/runtime/Exception.h"
#include "atlas/util/Factory.h"
#include "atlas/util/ObjectHandle.h"
Expand All @@ -37,9 +37,9 @@ namespace nonlinear {
class NonLinear : public util::Object {
public:
using Config = eckit::Parametrisation;
using Matrix = eckit::linalg::SparseMatrix;
using Scalar = eckit::linalg::Scalar;
using Size = eckit::linalg::Size;
using Matrix = atlas::linalg::SparseMatrix;
using Scalar = atlas::linalg::SparseMatrix::Scalar;
using Size = atlas::linalg::SparseMatrix::Size;

/**
* @brief ctor
Expand Down
3 changes: 1 addition & 2 deletions src/atlas/linalg/sparse/SparseMatrixMultiply.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
#pragma once

#include "eckit/config/Configuration.h"
#include "eckit/linalg/SparseMatrix.h"

#include "atlas/linalg/Indexing.h"
#include "atlas/linalg/SparseMatrix.h"
#include "atlas/linalg/View.h"
#include "atlas/linalg/sparse/Backend.h"
#include "atlas/runtime/Exception.h"
Expand All @@ -22,7 +22,6 @@
namespace atlas {
namespace linalg {

using SparseMatrix = eckit::linalg::SparseMatrix;
using Configuration = eckit::Configuration;

template <typename Matrix, typename SourceView, typename TargetView>
Expand Down
4 changes: 2 additions & 2 deletions src/atlas/linalg/sparse/SparseMatrixMultiply_EckitLinalg.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ void SparseMatrixMultiply<backend::eckit_linalg, Indexing::layout_right, 1, cons
ATLAS_ASSERT(tgt.contiguous());
eckit::linalg::Vector v_src(src.data(), src.size());
eckit::linalg::Vector v_tgt(tgt.data(), tgt.size());
eckit_linalg_backend(config).spmv(W, v_src, v_tgt);
eckit_linalg_backend(config).spmv(W.host_matrix(), v_src, v_tgt);
}

void SparseMatrixMultiply<backend::eckit_linalg, Indexing::layout_right, 2, const double, double>::multiply(
Expand All @@ -88,7 +88,7 @@ void SparseMatrixMultiply<backend::eckit_linalg, Indexing::layout_right, 2, cons
ATLAS_ASSERT(tgt.shape(1) >= W.rows());
eckit::linalg::Matrix m_src(src.data(), src.shape(1), src.shape(0));
eckit::linalg::Matrix m_tgt(tgt.data(), tgt.shape(1), tgt.shape(0));
eckit_linalg_backend(config).spmm(W, m_src, m_tgt);
eckit_linalg_backend(config).spmm(W.host_matrix(), m_src, m_tgt);
}

void SparseMatrixMultiply<backend::eckit_linalg, Indexing::layout_left, 1, const double, double>::multiply(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ CASE("extract cache, copy it, and move it for use") {

set_field(field_source, grid_source, func);

eckit::linalg::SparseMatrix matrix = get_or_create_cache(grid_source, grid_target).matrix();
atlas::linalg::SparseMatrix matrix = get_or_create_cache(grid_source, grid_target).matrix();

EXPECT(not matrix.empty());

Expand All @@ -133,7 +133,7 @@ CASE("extract cache, copy it, and pass non-owning pointer") {

set_field(field_source, grid_source, func);

eckit::linalg::SparseMatrix matrix = get_or_create_cache(grid_source, grid_target).matrix();
atlas::linalg::SparseMatrix matrix = get_or_create_cache(grid_source, grid_target).matrix();

EXPECT(not matrix.empty());

Expand Down

0 comments on commit 24321a5

Please sign in to comment.