From 7d75aad5a53bc1eca6e6ab8b9ad0cb3fad84b779 Mon Sep 17 00:00:00 2001 From: Pedro Maciel Date: Tue, 19 Nov 2024 07:31:07 +0000 Subject: [PATCH] MIR-678 improved weight matrix validation (exception throwing, added check for duplicates), fine grained method-specific matrix validation --- src/mir/caching/WeightCache.cc | 5 +--- src/mir/method/MethodWeighted.cc | 29 +++++++------------ src/mir/method/MethodWeighted.h | 5 ++-- src/mir/method/WeightMatrix.cc | 23 +++++++++++---- src/mir/method/WeightMatrix.h | 15 +++++++--- src/mir/method/gridbox/GridBoxMethod.cc | 4 +-- src/mir/method/gridbox/GridBoxMethod.h | 2 +- src/mir/method/knn/KNearestNeighbours.cc | 2 +- src/mir/method/knn/KNearestNeighbours.h | 2 +- .../method/knn/distance/DistanceWeighting.cc | 4 +-- .../method/knn/distance/DistanceWeighting.h | 2 +- src/mir/method/knn/distance/PseudoLaplace.cc | 4 +-- src/mir/method/knn/distance/PseudoLaplace.h | 2 +- src/mir/method/voronoi/VoronoiMethod.cc | 4 +-- src/mir/method/voronoi/VoronoiMethod.h | 2 +- tests/unit/weight_matrix.cc | 11 ++++--- 16 files changed, 60 insertions(+), 56 deletions(-) diff --git a/src/mir/caching/WeightCache.cc b/src/mir/caching/WeightCache.cc index 97acccc4f..a67508214 100644 --- a/src/mir/caching/WeightCache.cc +++ b/src/mir/caching/WeightCache.cc @@ -83,10 +83,7 @@ void WeightCacheTraits::load(const eckit::CacheManagerBase& manager, value_type& value_type tmp(matrix::MatrixLoaderFactory::build(manager.loader(), path)); w.swap(tmp); - static bool matrixValidate = eckit::Resource("$MIR_MATRIX_VALIDATE", false); - if (matrixValidate) { - w.validate("fromCache"); - } + w.validate("fromCache"); // check matrix structure (only) } diff --git a/src/mir/method/MethodWeighted.cc b/src/mir/method/MethodWeighted.cc index 34dfe2358..51be149a9 100644 --- a/src/mir/method/MethodWeighted.cc +++ b/src/mir/method/MethodWeighted.cc @@ -81,7 +81,6 @@ MethodWeighted::MethodWeighted(const param::MIRParametrisation& parametrisation) parametrisation_.get("pole-displacement-in-degree", poleDisplacement_); ASSERT(poleDisplacement_ >= 0); - matrixValidate_ = eckit::Resource("$MIR_MATRIX_VALIDATE", false); matrixAssemble_ = parametrisation_.userParametrisation().has("filter"); std::string nonLinear = "missing-if-heaviest-missing"; @@ -165,14 +164,15 @@ void MethodWeighted::createMatrix(context::Context& ctx, const repres::Represent const repres::Representation& out, WeightMatrix& W, const lsm::LandSeaMasks& masks, const Cropping& /*cropping*/) const { trace::ResourceUsage usage(std::string("MethodWeighted::createMatrix [") + name() + "]"); + const auto checks = validateMatrixWeights(); - computeMatrixWeights(ctx, in, out, W, validateMatrixWeights()); + // matrix validation always happens after creation, because the matrix can/will be cached + computeMatrixWeights(ctx, in, out, W); + W.validate("computeMatrixWeights", checks); if (masks.active() && masks.cacheable()) { applyMasks(W, masks); - if (matrixValidate_) { - W.validate("applyMasks"); - } + W.validate("applyMasks", checks); } } @@ -265,9 +265,7 @@ const WeightMatrix& MethodWeighted::getMatrix(context::Context& ctx, const repre // it will be cached in memory nevertheless if (masks.active() && !masks.cacheable()) { applyMasks(W, masks); - if (matrixValidate_) { - W.validate("applyMasks"); - } + W.validate("applyMasks", validateMatrixWeights()); } @@ -404,8 +402,8 @@ lsm::LandSeaMasks MethodWeighted::getMasks(const repres::Representation& in, con } -bool MethodWeighted::validateMatrixWeights() const { - return true; +WeightMatrix::Check MethodWeighted::validateMatrixWeights() const { + return {}; } @@ -494,9 +492,7 @@ void MethodWeighted::execute(context::Context& ctx, const repres::Representation trace::Timer t(str.str()); if (n->treatment(A, M, B, field.values(i), missingValue)) { - if (matrixValidate_) { - M.validate(str.str().c_str()); - } + M.validate(str.str().c_str(), validateMatrixWeights()); } } @@ -528,7 +524,7 @@ void MethodWeighted::execute(context::Context& ctx, const repres::Representation void MethodWeighted::computeMatrixWeights(context::Context& ctx, const repres::Representation& in, - const repres::Representation& out, WeightMatrix& W, bool validate) const { + const repres::Representation& out, WeightMatrix& W) const { auto timing(ctx.statistics().computeMatrixTimer()); if (in.sameAs(out) && !matrixAssemble_) { @@ -565,11 +561,6 @@ void MethodWeighted::computeMatrixWeights(context::Context& ctx, const repres::R W.swap(w); } } - - // matrix validation always happens after creation, because the matrix can/will be cached - if (validate) { - W.validate("computeMatrixWeights"); - } } diff --git a/src/mir/method/MethodWeighted.h b/src/mir/method/MethodWeighted.h index 8ff6e1556..0c81775fe 100644 --- a/src/mir/method/MethodWeighted.h +++ b/src/mir/method/MethodWeighted.h @@ -114,7 +114,6 @@ class MethodWeighted : public Method { std::unique_ptr reorderRows_; std::unique_ptr reorderCols_; - bool matrixValidate_; bool matrixAssemble_; // -- Methods @@ -127,10 +126,10 @@ class MethodWeighted : public Method { virtual void applyMasks(WeightMatrix&, const lsm::LandSeaMasks&) const; virtual lsm::LandSeaMasks getMasks(const repres::Representation& in, const repres::Representation& out) const; - virtual bool validateMatrixWeights() const; + virtual WeightMatrix::Check validateMatrixWeights() const; void computeMatrixWeights(context::Context&, const repres::Representation& in, const repres::Representation& out, - WeightMatrix&, bool validate) const; + WeightMatrix&) const; void createMatrix(context::Context&, const repres::Representation& in, const repres::Representation& out, WeightMatrix&, const lsm::LandSeaMasks&, const Cropping&) const; diff --git a/src/mir/method/WeightMatrix.cc b/src/mir/method/WeightMatrix.cc index 80822157d..3bb3ca2eb 100644 --- a/src/mir/method/WeightMatrix.cc +++ b/src/mir/method/WeightMatrix.cc @@ -16,6 +16,7 @@ #include #include +#include "eckit/config/Resource.h" #include "eckit/types/FloatCompare.h" #include "mir/util/Exceptions.h" @@ -102,7 +103,12 @@ void WeightMatrix::cleanup(const double& pruneEpsilon) { } -void WeightMatrix::validate(const char* when) const { +void WeightMatrix::validate(const char* when, Check check) const { + static bool matrixValidate = eckit::Resource("$MIR_MATRIX_VALIDATE", true); + if (!matrixValidate || (!check.duplicates && !check.bounds && !check.sum)) { + return; + } + constexpr size_t Nerrors = 10; size_t errors = 0; @@ -114,20 +120,25 @@ void WeightMatrix::validate(const char* when) const { Scalar sum = 0.; std::unordered_set cols; - bool check_bounds = true; - bool check_no_duplicates = true; + auto check_bounds = true; + auto check_duplicates = true; for (auto it = begin(r); it != end(r); ++it) { auto a = *it; check_bounds &= eckit::types::is_approximately_greater_or_equal(a, 0.) && eckit::types::is_approximately_greater_or_equal(1., a); sum += a; - check_no_duplicates &= cols.insert(it.col()).second; + check_duplicates &= cols.insert(it.col()).second; } auto check_sum = eckit::types::is_approximately_equal(sum, 0.) || eckit::types::is_approximately_equal(sum, 1.); - if (!check_bounds || !check_sum || !check_no_duplicates) { + // ignore checks as required + check_duplicates |= !check.duplicates; + check_bounds |= !check.bounds; + check_sum |= !check.sum; + + if (!check_bounds || !check_sum || !check_duplicates) { if (errors < Nerrors) { what << sep << "row " << r << ": "; const char* s = ""; @@ -142,7 +153,7 @@ void WeightMatrix::validate(const char* when) const { s = ", "; } - if (!check_no_duplicates) { + if (!check_duplicates) { what << s << "duplicate indices"; s = ", "; } diff --git a/src/mir/method/WeightMatrix.h b/src/mir/method/WeightMatrix.h index 7695c879d..908ba0f8b 100644 --- a/src/mir/method/WeightMatrix.h +++ b/src/mir/method/WeightMatrix.h @@ -22,12 +22,18 @@ namespace mir::method { class WeightMatrix final : public eckit::linalg::SparseMatrix { -public: // types +public: using Triplet = eckit::linalg::Triplet; using Scalar = eckit::linalg::Scalar; using Size = eckit::linalg::Size; -public: // methods + struct Check { + bool duplicates = true; + bool bounds = true; + bool sum = true; + }; + +public: WeightMatrix(SparseMatrix::Allocator* = nullptr); WeightMatrix(const eckit::PathName&); @@ -38,9 +44,10 @@ class WeightMatrix final : public eckit::linalg::SparseMatrix { void cleanup(const double& pruneEpsilon = 0); - void validate(const char* when) const; + // Validate interpolation weights (default check matrix structure only) + void validate(const char* when, Check = {true, false, false}) const; -private: // members +private: void print(std::ostream&) const; friend std::ostream& operator<<(std::ostream& out, const WeightMatrix& m) { diff --git a/src/mir/method/gridbox/GridBoxMethod.cc b/src/mir/method/gridbox/GridBoxMethod.cc index 4e917e3ba..938007914 100644 --- a/src/mir/method/gridbox/GridBoxMethod.cc +++ b/src/mir/method/gridbox/GridBoxMethod.cc @@ -45,8 +45,8 @@ void GridBoxMethod::print(std::ostream& out) const { } -bool GridBoxMethod::validateMatrixWeights() const { - return false; +WeightMatrix::Check GridBoxMethod::validateMatrixWeights() const { + return {true, true, false}; } diff --git a/src/mir/method/gridbox/GridBoxMethod.h b/src/mir/method/gridbox/GridBoxMethod.h index e36210ac2..5708a321c 100644 --- a/src/mir/method/gridbox/GridBoxMethod.h +++ b/src/mir/method/gridbox/GridBoxMethod.h @@ -27,7 +27,7 @@ class GridBoxMethod : public MethodWeighted { bool sameAs(const Method&) const override; void json(eckit::JSON&) const override; void print(std::ostream&) const override; - bool validateMatrixWeights() const override; + WeightMatrix::Check validateMatrixWeights() const override; }; diff --git a/src/mir/method/knn/KNearestNeighbours.cc b/src/mir/method/knn/KNearestNeighbours.cc index f08cf7a8a..531b48d31 100644 --- a/src/mir/method/knn/KNearestNeighbours.cc +++ b/src/mir/method/knn/KNearestNeighbours.cc @@ -162,7 +162,7 @@ void KNearestNeighbours::print(std::ostream& out) const { } -bool KNearestNeighbours::validateMatrixWeights() const { +WeightMatrix::Check KNearestNeighbours::validateMatrixWeights() const { return distanceWeighting().validateMatrixWeights(); } diff --git a/src/mir/method/knn/KNearestNeighbours.h b/src/mir/method/knn/KNearestNeighbours.h index 96e5ad69b..7da3e2e22 100644 --- a/src/mir/method/knn/KNearestNeighbours.h +++ b/src/mir/method/knn/KNearestNeighbours.h @@ -56,7 +56,7 @@ class KNearestNeighbours : public MethodWeighted { virtual const pick::Pick& pick() const = 0; virtual const distance::DistanceWeighting& distanceWeighting() const = 0; - virtual bool validateMatrixWeights() const; + WeightMatrix::Check validateMatrixWeights() const override; }; diff --git a/src/mir/method/knn/distance/DistanceWeighting.cc b/src/mir/method/knn/distance/DistanceWeighting.cc index 9a326ab1f..7dd3d2884 100644 --- a/src/mir/method/knn/distance/DistanceWeighting.cc +++ b/src/mir/method/knn/distance/DistanceWeighting.cc @@ -85,8 +85,8 @@ void DistanceWeightingFactory::list(std::ostream& out) { } -bool DistanceWeighting::validateMatrixWeights() const { - return true; +WeightMatrix::Check DistanceWeighting::validateMatrixWeights() const { + return {}; } diff --git a/src/mir/method/knn/distance/DistanceWeighting.h b/src/mir/method/knn/distance/DistanceWeighting.h index f75ba1766..2a3c85f88 100644 --- a/src/mir/method/knn/distance/DistanceWeighting.h +++ b/src/mir/method/knn/distance/DistanceWeighting.h @@ -47,7 +47,7 @@ class DistanceWeighting { virtual void hash(eckit::MD5&) const = 0; - virtual bool validateMatrixWeights() const; + virtual WeightMatrix::Check validateMatrixWeights() const; private: virtual void json(eckit::JSON&) const = 0; diff --git a/src/mir/method/knn/distance/PseudoLaplace.cc b/src/mir/method/knn/distance/PseudoLaplace.cc index 78acead34..b0f59bc02 100644 --- a/src/mir/method/knn/distance/PseudoLaplace.cc +++ b/src/mir/method/knn/distance/PseudoLaplace.cc @@ -127,9 +127,9 @@ void PseudoLaplace::hash(eckit::MD5& h) const { } -bool PseudoLaplace::validateMatrixWeights() const { +WeightMatrix::Check PseudoLaplace::validateMatrixWeights() const { // this method does not produce bounded interpolation weights - return false; + return {true, false, false}; } diff --git a/src/mir/method/knn/distance/PseudoLaplace.h b/src/mir/method/knn/distance/PseudoLaplace.h index 1d2b02974..4c87b4860 100644 --- a/src/mir/method/knn/distance/PseudoLaplace.h +++ b/src/mir/method/knn/distance/PseudoLaplace.h @@ -29,7 +29,7 @@ struct PseudoLaplace : DistanceWeighting { void print(std::ostream&) const override; void hash(eckit::MD5&) const override; - bool validateMatrixWeights() const override; + WeightMatrix::Check validateMatrixWeights() const override; }; diff --git a/src/mir/method/voronoi/VoronoiMethod.cc b/src/mir/method/voronoi/VoronoiMethod.cc index 4672cf89f..7fcdbba18 100644 --- a/src/mir/method/voronoi/VoronoiMethod.cc +++ b/src/mir/method/voronoi/VoronoiMethod.cc @@ -168,8 +168,8 @@ void VoronoiMethod::print(std::ostream& out) const { } -bool VoronoiMethod::validateMatrixWeights() const { - return false; +WeightMatrix::Check VoronoiMethod::validateMatrixWeights() const { + return {true, true, false}; } diff --git a/src/mir/method/voronoi/VoronoiMethod.h b/src/mir/method/voronoi/VoronoiMethod.h index 7a2dbd76c..0f13769dc 100644 --- a/src/mir/method/voronoi/VoronoiMethod.h +++ b/src/mir/method/voronoi/VoronoiMethod.h @@ -30,7 +30,7 @@ class VoronoiMethod : public MethodWeighted { bool sameAs(const Method&) const override; void json(eckit::JSON&) const override; void print(std::ostream&) const override; - bool validateMatrixWeights() const override; + WeightMatrix::Check validateMatrixWeights() const override; const char* name() const override; knn::pick::NClosestOrNearest pick_; diff --git a/tests/unit/weight_matrix.cc b/tests/unit/weight_matrix.cc index 02b75a435..9d252ec4f 100644 --- a/tests/unit/weight_matrix.cc +++ b/tests/unit/weight_matrix.cc @@ -24,14 +24,13 @@ CASE("WeightMatrix::validate") { const auto* when{"out-of-bounds"}; const std::string what{ "Invalid weight matrix (out-of-bounds): 1 row error, " - "row 2: weights out-of-bounds, " - "weights sum not 0 or 1 (sum=-0.1, 1-sum=1.1), contents: (2,2,-0.1)"}; + "row 2: weights out-of-bounds, contents: (2,2,-0.1)"}; method::WeightMatrix W(3, 3); W.setFromTriplets({{0, 0, 1.}, {1, 1, 0.}, {2, 2, -0.1}}); try { - W.validate(when); + W.validate(when, {false, true, false}); ASSERT(false); } catch (exception::InvalidWeightMatrix& e) { @@ -51,7 +50,7 @@ CASE("WeightMatrix::validate") { W.setFromTriplets({{0, 0, 1.}, {1, 1, 0.5}, {2, 2, 0.1}}); try { - W.validate(when); + W.validate(when, {false, false, true}); ASSERT(false); } catch (exception::InvalidWeightMatrix& e) { @@ -70,7 +69,7 @@ CASE("WeightMatrix::validate") { W.setFromTriplets({{0, 0, 1.}, {1, 1, 0.5}, {1, 1, 0.5}}); try { - W.validate(when); + W.validate(when, {true, false, false}); ASSERT(false); } catch (exception::InvalidWeightMatrix& e) { @@ -90,7 +89,7 @@ CASE("WeightMatrix::validate") { W.setFromTriplets({{0, 0, 0.5}, {0, 0, 0.5}, {1, 1, 0.5}}); try { - W.validate(when); + W.validate(when, {true, false, true}); ASSERT(false); } catch (exception::InvalidWeightMatrix& e) {