Skip to content

Commit

Permalink
MIR-678 improved weight matrix validation (exception throwing, added …
Browse files Browse the repository at this point in the history
…check for duplicates), fine grained method-specific matrix validation
  • Loading branch information
pmaciel committed Nov 19, 2024
1 parent a944ff7 commit 7d75aad
Show file tree
Hide file tree
Showing 16 changed files with 60 additions and 56 deletions.
5 changes: 1 addition & 4 deletions src/mir/caching/WeightCache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,7 @@ void WeightCacheTraits::load(const eckit::CacheManagerBase& manager, value_type&
value_type tmp(matrix::MatrixLoaderFactory::build(manager.loader(), path));
w.swap(tmp);

static bool matrixValidate = eckit::Resource<bool>("$MIR_MATRIX_VALIDATE", false);
if (matrixValidate) {
w.validate("fromCache");
}
w.validate("fromCache"); // check matrix structure (only)
}


Expand Down
29 changes: 10 additions & 19 deletions src/mir/method/MethodWeighted.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ MethodWeighted::MethodWeighted(const param::MIRParametrisation& parametrisation)
parametrisation_.get("pole-displacement-in-degree", poleDisplacement_);
ASSERT(poleDisplacement_ >= 0);

matrixValidate_ = eckit::Resource<bool>("$MIR_MATRIX_VALIDATE", false);
matrixAssemble_ = parametrisation_.userParametrisation().has("filter");

std::string nonLinear = "missing-if-heaviest-missing";
Expand Down Expand Up @@ -165,14 +164,15 @@ void MethodWeighted::createMatrix(context::Context& ctx, const repres::Represent
const repres::Representation& out, WeightMatrix& W, const lsm::LandSeaMasks& masks,
const Cropping& /*cropping*/) const {
trace::ResourceUsage usage(std::string("MethodWeighted::createMatrix [") + name() + "]");
const auto checks = validateMatrixWeights();

computeMatrixWeights(ctx, in, out, W, validateMatrixWeights());
// matrix validation always happens after creation, because the matrix can/will be cached
computeMatrixWeights(ctx, in, out, W);
W.validate("computeMatrixWeights", checks);

if (masks.active() && masks.cacheable()) {
applyMasks(W, masks);
if (matrixValidate_) {
W.validate("applyMasks");
}
W.validate("applyMasks", checks);
}
}

Expand Down Expand Up @@ -265,9 +265,7 @@ const WeightMatrix& MethodWeighted::getMatrix(context::Context& ctx, const repre
// it will be cached in memory nevertheless
if (masks.active() && !masks.cacheable()) {
applyMasks(W, masks);
if (matrixValidate_) {
W.validate("applyMasks");
}
W.validate("applyMasks", validateMatrixWeights());
}


Expand Down Expand Up @@ -404,8 +402,8 @@ lsm::LandSeaMasks MethodWeighted::getMasks(const repres::Representation& in, con
}


bool MethodWeighted::validateMatrixWeights() const {
return true;
WeightMatrix::Check MethodWeighted::validateMatrixWeights() const {
return {};
}


Expand Down Expand Up @@ -494,9 +492,7 @@ void MethodWeighted::execute(context::Context& ctx, const repres::Representation
trace::Timer t(str.str());

if (n->treatment(A, M, B, field.values(i), missingValue)) {
if (matrixValidate_) {
M.validate(str.str().c_str());
}
M.validate(str.str().c_str(), validateMatrixWeights());
}
}

Expand Down Expand Up @@ -528,7 +524,7 @@ void MethodWeighted::execute(context::Context& ctx, const repres::Representation


void MethodWeighted::computeMatrixWeights(context::Context& ctx, const repres::Representation& in,
const repres::Representation& out, WeightMatrix& W, bool validate) const {
const repres::Representation& out, WeightMatrix& W) const {
auto timing(ctx.statistics().computeMatrixTimer());

if (in.sameAs(out) && !matrixAssemble_) {
Expand Down Expand Up @@ -565,11 +561,6 @@ void MethodWeighted::computeMatrixWeights(context::Context& ctx, const repres::R
W.swap(w);
}
}

// matrix validation always happens after creation, because the matrix can/will be cached
if (validate) {
W.validate("computeMatrixWeights");
}
}


Expand Down
5 changes: 2 additions & 3 deletions src/mir/method/MethodWeighted.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@ class MethodWeighted : public Method {
std::unique_ptr<const reorder::Reorder> reorderRows_;
std::unique_ptr<const reorder::Reorder> reorderCols_;

bool matrixValidate_;
bool matrixAssemble_;

// -- Methods
Expand All @@ -127,10 +126,10 @@ class MethodWeighted : public Method {

virtual void applyMasks(WeightMatrix&, const lsm::LandSeaMasks&) const;
virtual lsm::LandSeaMasks getMasks(const repres::Representation& in, const repres::Representation& out) const;
virtual bool validateMatrixWeights() const;
virtual WeightMatrix::Check validateMatrixWeights() const;

void computeMatrixWeights(context::Context&, const repres::Representation& in, const repres::Representation& out,
WeightMatrix&, bool validate) const;
WeightMatrix&) const;
void createMatrix(context::Context&, const repres::Representation& in, const repres::Representation& out,
WeightMatrix&, const lsm::LandSeaMasks&, const Cropping&) const;

Expand Down
23 changes: 17 additions & 6 deletions src/mir/method/WeightMatrix.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <sstream>
#include <unordered_set>

#include "eckit/config/Resource.h"
#include "eckit/types/FloatCompare.h"

#include "mir/util/Exceptions.h"
Expand Down Expand Up @@ -102,7 +103,12 @@ void WeightMatrix::cleanup(const double& pruneEpsilon) {
}


void WeightMatrix::validate(const char* when) const {
void WeightMatrix::validate(const char* when, Check check) const {
static bool matrixValidate = eckit::Resource<bool>("$MIR_MATRIX_VALIDATE", true);
if (!matrixValidate || (!check.duplicates && !check.bounds && !check.sum)) {
return;
}

constexpr size_t Nerrors = 10;
size_t errors = 0;

Expand All @@ -114,20 +120,25 @@ void WeightMatrix::validate(const char* when) const {
Scalar sum = 0.;
std::unordered_set<Size> cols;

bool check_bounds = true;
bool check_no_duplicates = true;
auto check_bounds = true;
auto check_duplicates = true;
for (auto it = begin(r); it != end(r); ++it) {
auto a = *it;
check_bounds &= eckit::types::is_approximately_greater_or_equal(a, 0.) &&
eckit::types::is_approximately_greater_or_equal(1., a);
sum += a;

check_no_duplicates &= cols.insert(it.col()).second;
check_duplicates &= cols.insert(it.col()).second;
}

auto check_sum = eckit::types::is_approximately_equal(sum, 0.) || eckit::types::is_approximately_equal(sum, 1.);

if (!check_bounds || !check_sum || !check_no_duplicates) {
// ignore checks as required
check_duplicates |= !check.duplicates;
check_bounds |= !check.bounds;
check_sum |= !check.sum;

if (!check_bounds || !check_sum || !check_duplicates) {
if (errors < Nerrors) {
what << sep << "row " << r << ": ";
const char* s = "";
Expand All @@ -142,7 +153,7 @@ void WeightMatrix::validate(const char* when) const {
s = ", ";
}

if (!check_no_duplicates) {
if (!check_duplicates) {
what << s << "duplicate indices";
s = ", ";
}
Expand Down
15 changes: 11 additions & 4 deletions src/mir/method/WeightMatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,18 @@ namespace mir::method {


class WeightMatrix final : public eckit::linalg::SparseMatrix {
public: // types
public:
using Triplet = eckit::linalg::Triplet;
using Scalar = eckit::linalg::Scalar;
using Size = eckit::linalg::Size;

public: // methods
struct Check {
bool duplicates = true;
bool bounds = true;
bool sum = true;
};

public:
WeightMatrix(SparseMatrix::Allocator* = nullptr);

WeightMatrix(const eckit::PathName&);
Expand All @@ -38,9 +44,10 @@ class WeightMatrix final : public eckit::linalg::SparseMatrix {

void cleanup(const double& pruneEpsilon = 0);

void validate(const char* when) const;
// Validate interpolation weights (default check matrix structure only)
void validate(const char* when, Check = {true, false, false}) const;

private: // members
private:
void print(std::ostream&) const;

friend std::ostream& operator<<(std::ostream& out, const WeightMatrix& m) {
Expand Down
4 changes: 2 additions & 2 deletions src/mir/method/gridbox/GridBoxMethod.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ void GridBoxMethod::print(std::ostream& out) const {
}


bool GridBoxMethod::validateMatrixWeights() const {
return false;
WeightMatrix::Check GridBoxMethod::validateMatrixWeights() const {
return {true, true, false};
}


Expand Down
2 changes: 1 addition & 1 deletion src/mir/method/gridbox/GridBoxMethod.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class GridBoxMethod : public MethodWeighted {
bool sameAs(const Method&) const override;
void json(eckit::JSON&) const override;
void print(std::ostream&) const override;
bool validateMatrixWeights() const override;
WeightMatrix::Check validateMatrixWeights() const override;
};


Expand Down
2 changes: 1 addition & 1 deletion src/mir/method/knn/KNearestNeighbours.cc
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ void KNearestNeighbours::print(std::ostream& out) const {
}


bool KNearestNeighbours::validateMatrixWeights() const {
WeightMatrix::Check KNearestNeighbours::validateMatrixWeights() const {
return distanceWeighting().validateMatrixWeights();
}

Expand Down
2 changes: 1 addition & 1 deletion src/mir/method/knn/KNearestNeighbours.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class KNearestNeighbours : public MethodWeighted {
virtual const pick::Pick& pick() const = 0;
virtual const distance::DistanceWeighting& distanceWeighting() const = 0;

virtual bool validateMatrixWeights() const;
WeightMatrix::Check validateMatrixWeights() const override;
};


Expand Down
4 changes: 2 additions & 2 deletions src/mir/method/knn/distance/DistanceWeighting.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ void DistanceWeightingFactory::list(std::ostream& out) {
}


bool DistanceWeighting::validateMatrixWeights() const {
return true;
WeightMatrix::Check DistanceWeighting::validateMatrixWeights() const {
return {};
}


Expand Down
2 changes: 1 addition & 1 deletion src/mir/method/knn/distance/DistanceWeighting.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class DistanceWeighting {

virtual void hash(eckit::MD5&) const = 0;

virtual bool validateMatrixWeights() const;
virtual WeightMatrix::Check validateMatrixWeights() const;

private:
virtual void json(eckit::JSON&) const = 0;
Expand Down
4 changes: 2 additions & 2 deletions src/mir/method/knn/distance/PseudoLaplace.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,9 @@ void PseudoLaplace::hash(eckit::MD5& h) const {
}


bool PseudoLaplace::validateMatrixWeights() const {
WeightMatrix::Check PseudoLaplace::validateMatrixWeights() const {
// this method does not produce bounded interpolation weights
return false;
return {true, false, false};
}


Expand Down
2 changes: 1 addition & 1 deletion src/mir/method/knn/distance/PseudoLaplace.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ struct PseudoLaplace : DistanceWeighting {
void print(std::ostream&) const override;
void hash(eckit::MD5&) const override;

bool validateMatrixWeights() const override;
WeightMatrix::Check validateMatrixWeights() const override;
};


Expand Down
4 changes: 2 additions & 2 deletions src/mir/method/voronoi/VoronoiMethod.cc
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,8 @@ void VoronoiMethod::print(std::ostream& out) const {
}


bool VoronoiMethod::validateMatrixWeights() const {
return false;
WeightMatrix::Check VoronoiMethod::validateMatrixWeights() const {
return {true, true, false};
}


Expand Down
2 changes: 1 addition & 1 deletion src/mir/method/voronoi/VoronoiMethod.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class VoronoiMethod : public MethodWeighted {
bool sameAs(const Method&) const override;
void json(eckit::JSON&) const override;
void print(std::ostream&) const override;
bool validateMatrixWeights() const override;
WeightMatrix::Check validateMatrixWeights() const override;
const char* name() const override;

knn::pick::NClosestOrNearest pick_;
Expand Down
11 changes: 5 additions & 6 deletions tests/unit/weight_matrix.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,13 @@ CASE("WeightMatrix::validate") {
const auto* when{"out-of-bounds"};
const std::string what{
"Invalid weight matrix (out-of-bounds): 1 row error, "
"row 2: weights out-of-bounds, "
"weights sum not 0 or 1 (sum=-0.1, 1-sum=1.1), contents: (2,2,-0.1)"};
"row 2: weights out-of-bounds, contents: (2,2,-0.1)"};

method::WeightMatrix W(3, 3);
W.setFromTriplets({{0, 0, 1.}, {1, 1, 0.}, {2, 2, -0.1}});

try {
W.validate(when);
W.validate(when, {false, true, false});
ASSERT(false);
}
catch (exception::InvalidWeightMatrix& e) {
Expand All @@ -51,7 +50,7 @@ CASE("WeightMatrix::validate") {
W.setFromTriplets({{0, 0, 1.}, {1, 1, 0.5}, {2, 2, 0.1}});

try {
W.validate(when);
W.validate(when, {false, false, true});
ASSERT(false);
}
catch (exception::InvalidWeightMatrix& e) {
Expand All @@ -70,7 +69,7 @@ CASE("WeightMatrix::validate") {
W.setFromTriplets({{0, 0, 1.}, {1, 1, 0.5}, {1, 1, 0.5}});

try {
W.validate(when);
W.validate(when, {true, false, false});
ASSERT(false);
}
catch (exception::InvalidWeightMatrix& e) {
Expand All @@ -90,7 +89,7 @@ CASE("WeightMatrix::validate") {
W.setFromTriplets({{0, 0, 0.5}, {0, 0, 0.5}, {1, 1, 0.5}});

try {
W.validate(when);
W.validate(when, {true, false, true});
ASSERT(false);
}
catch (exception::InvalidWeightMatrix& e) {
Expand Down

0 comments on commit 7d75aad

Please sign in to comment.