diff --git a/src/mir/method/MethodWeighted.cc b/src/mir/method/MethodWeighted.cc index 51be149a9..7e1574e16 100644 --- a/src/mir/method/MethodWeighted.cc +++ b/src/mir/method/MethodWeighted.cc @@ -174,6 +174,18 @@ void MethodWeighted::createMatrix(context::Context& ctx, const repres::Represent applyMasks(W, masks); W.validate("applyMasks", checks); } + + bool bitmask; + parametrisation_.get("imm", bitmask); + if (bitmask){ + std::vector vec_bitmask; + parametrisation_.get("imm-mask", vec_bitmask); + applyIMM(W,vec_bitmask); + if (matrixValidate_) { + W.validate("applyMasks"); + } + } + } @@ -208,6 +220,14 @@ MethodWeighted::CacheKeys MethodWeighted::getDiskAndMemoryCacheKeys(const repres } } + bool bitmask; + parametrisation_.get("imm", bitmask); + if (bitmask){ + std::string missing_mask_key; + parametrisation_.get("imm-name", missing_mask_key); + memory_key += "_" + missing_mask_key; + disk_key += "_" + missing_mask_key; + } return {disk_key, memory_key}; } @@ -450,12 +470,16 @@ void MethodWeighted::execute(context::Context& ctx, const repres::Representation const double missingValue = field.missingValue(); // matrix copy: run-time modifiable matrix is not cacheable - const bool matrixCopy = std::any_of(nonLinear_.begin(), nonLinear_.end(), + bool imm_active; + parametrisation_.get("imm", imm_active); + bool matrixCopy = false; + if (!imm_active) { + matrixCopy = std::any_of(nonLinear_.begin(), nonLinear_.end(), [&field](const std::unique_ptr& n) { return n->modifiesMatrix(field.hasMissing()); }); - - + } + for (size_t i = 0; i < field.dimensions(); i++) { std::ostringstream os; @@ -621,6 +645,73 @@ void MethodWeighted::applyMasks(WeightMatrix& W, const lsm::LandSeaMasks& masks) << Log::Pretty(W.rows(), {"output point"}) << std::endl; } +void MethodWeighted::applyIMM(WeightMatrix& W, const std::vector& imask) const { + trace::Timer timer("MethodWeighted::applyIMM"); +auto& log = Log::debug(); + +log << "MethodWeighted::applyIMM" << std::endl; +log << "mask size=" << imask.size() << " == #cols=" << W.cols() << std::endl; + +// Ensure that the mask size matches the number of columns in W +ASSERT(imask.size() == W.cols()); + +auto* data = const_cast(W.data()); +auto* outer = W.outer(); +auto* inner = W.inner(); +size_t fix = 0; + +// Iterate over each row +for (WeightMatrix::Size r = 0; r < W.rows(); ++r) { + WeightMatrix::Size row_start = outer[r]; + WeightMatrix::Size row_end = outer[r + 1]; + + size_t i_missing = row_start; + size_t N_missing = 0; + size_t N_entries = row_end - row_start; + double sum = 0.0; + double heaviest = -1.0; + bool heaviest_is_missing = false; + + // Iterate over the entries in the current row + for (WeightMatrix::Size i = row_start; i < row_end; ++i) { + const bool miss = !imask[inner[i]]; + + if (miss) { + ++N_missing; + i_missing = i; + } else { + sum += data[i]; + } + + if (heaviest < data[i]) { + heaviest = data[i]; + heaviest_is_missing = miss; + } + } + + // Weights redistribution: zero-weight all missing values, linear re-weighting for the others + if (N_missing > 0) { + ++fix; + if (N_missing == N_entries || heaviest_is_missing || eckit::types::is_approximately_equal(sum, 0.0)) { + // All values are missing or the heaviest is missing; set only i_missing to 1 + for (WeightMatrix::Size i = row_start; i < row_end; ++i) { + data[i] = (i == i_missing) ? 1.0 : 0.0; + } + } else { + // Scale non-missing entries so they sum to 1 + const double factor = 1.0 / sum; + for (WeightMatrix::Size i = row_start; i < row_end; ++i) { + const bool miss = !imask[inner[i]]; + data[i] = miss ? 0.0 : (factor * data[i]); + } + } + } +} + +// Log the number of corrections made +log << "MethodWeighted: applyIMM corrected " << Log::Pretty(fix) << " out of " + << Log::Pretty(W.rows(), {"Weight matrix rows"}) << std::endl; +} void MethodWeighted::hash(eckit::MD5& md5) const { md5.add(name()); diff --git a/src/mir/method/MethodWeighted.h b/src/mir/method/MethodWeighted.h index 0c81775fe..1a6f4238a 100644 --- a/src/mir/method/MethodWeighted.h +++ b/src/mir/method/MethodWeighted.h @@ -125,6 +125,7 @@ class MethodWeighted : public Method { const lsm::LandSeaMasks&) const; virtual void applyMasks(WeightMatrix&, const lsm::LandSeaMasks&) const; + virtual void applyIMM(WeightMatrix&, const std::vector&) const; virtual lsm::LandSeaMasks getMasks(const repres::Representation& in, const repres::Representation& out) const; virtual WeightMatrix::Check validateMatrixWeights() const; diff --git a/src/mir/param/CombinedParametrisation.cc b/src/mir/param/CombinedParametrisation.cc index 4635e5a9f..653e27681 100644 --- a/src/mir/param/CombinedParametrisation.cc +++ b/src/mir/param/CombinedParametrisation.cc @@ -108,6 +108,10 @@ bool CombinedParametrisation::get(const std::string& name, std::vector& return _get(name, value); } +bool CombinedParametrisation::get(const std::string& name, std::vector& value) const { + return _get(name, value); +} + bool CombinedParametrisation::get(const std::string& name, std::vector& value) const { return _get(name, value); } diff --git a/src/mir/param/CombinedParametrisation.h b/src/mir/param/CombinedParametrisation.h index 884ec8fc6..13960181e 100644 --- a/src/mir/param/CombinedParametrisation.h +++ b/src/mir/param/CombinedParametrisation.h @@ -100,6 +100,7 @@ class CombinedParametrisation : public MIRParametrisation { bool get(const std::string& name, std::vector& value) const override; bool get(const std::string& name, std::vector& value) const override; bool get(const std::string& name, std::vector& value) const override; + bool get(const std::string& name, std::vector& value) const override; bool get(const std::string& name, std::vector& value) const override; // -- Class members diff --git a/src/mir/param/DefaultParametrisation.cc b/src/mir/param/DefaultParametrisation.cc index 8d76e4693..bc0d414eb 100644 --- a/src/mir/param/DefaultParametrisation.cc +++ b/src/mir/param/DefaultParametrisation.cc @@ -37,6 +37,10 @@ DefaultParametrisation::DefaultParametrisation() { set("lsm-weight-adjustment", 0.2); set("lsm-value-threshold", 0.5); + set("imm", false); + set("imm-name", "bitmasked"); + set("imm-mask", std::vector{}); + set("spectral-order", "linear"); set("compare", "scalar"); diff --git a/src/mir/param/MIRParametrisation.cc b/src/mir/param/MIRParametrisation.cc index 882088a72..566ec8efc 100644 --- a/src/mir/param/MIRParametrisation.cc +++ b/src/mir/param/MIRParametrisation.cc @@ -89,4 +89,9 @@ bool MIRParametrisation::get(const std::string& name, std::vector& va return false; } +bool MIRParametrisation::get(const std::string& name, std::vector& value) const { + std::ostringstream os; + os << "MIRParametrisation::get(const std::string& name, std::vector& value) not implemented for " << *this; + throw exception::SeriousBug(os.str()); + } } // namespace mir::param diff --git a/src/mir/param/MIRParametrisation.h b/src/mir/param/MIRParametrisation.h index ad8efc61d..b2d52efea 100644 --- a/src/mir/param/MIRParametrisation.h +++ b/src/mir/param/MIRParametrisation.h @@ -63,6 +63,7 @@ class MIRParametrisation : public eckit::Parametrisation { bool get(const std::string& name, std::vector& value) const override = 0; bool get(const std::string& name, std::vector& value) const override = 0; bool get(const std::string& name, std::vector& value) const override = 0; + virtual bool get(const std::string& name, std::vector& value) const ; bool get(const std::string& name, size_t& value) const override; bool get(const std::string& name, std::vector& value) const override; diff --git a/src/mir/param/SimpleParametrisation.cc b/src/mir/param/SimpleParametrisation.cc index 57177b3da..839f03171 100644 --- a/src/mir/param/SimpleParametrisation.cc +++ b/src/mir/param/SimpleParametrisation.cc @@ -54,6 +54,7 @@ class Setting { virtual void get(const std::string& name, std::vector& value) const = 0; virtual void get(const std::string& name, std::vector& value) const = 0; virtual void get(const std::string& name, std::vector& value) const = 0; + virtual void get(const std::string& name, std::vector& value) const = 0; virtual bool matchAll(const std::string& name, const MIRParametrisation&) const = 0; virtual bool matchAny(const std::string& name, const MIRParametrisation&) const = 0; @@ -158,6 +159,12 @@ const char* TNamed>() { return "vector"; } +template <> +const char* TNamed>() { + return "vector"; +} + + template static void conversion_warning(const char* /*from*/, const char* /*to*/, const std::string& /*name*/, @@ -226,6 +233,10 @@ class TSettings : public Setting { throw exception::CannotConvert(TNamed(), "vector", name, value_); } + void get(const std::string& name, std::vector& /*value*/) const { + throw exception::CannotConvert(TNamed(), "vector", name, value_); + } + bool matchAll(const std::string& name, const MIRParametrisation& other) const override { T value; return other.get(name, value) && value_ == value; @@ -269,6 +280,11 @@ void TSettings>::print(std::ostream& out) const { _put(out, value_); } +template <> +void TSettings>::print(std::ostream& out) const { + _put(out, value_); // Uses the `_put` helper function for formatting +} + template <> bool TSettings>::matchAll(const std::string& name, const MIRParametrisation& other) const { @@ -443,6 +459,10 @@ void TSettings>::get(const std::string& /*name*/, std::vecto value = value_; } +template <> +void TSettings>::get(const std::string& /*name*/, std::vector& value) const { + value = value_; +} template <> void TSettings::get(const std::string& name, std::string& value) const { @@ -555,6 +575,18 @@ void TSettings>::get(const std::string& name, std::stri } } +template <> +void TSettings>::get(const std::string& name, std::string& value) const { + conversion_warning("vector", "string", name, value_); + value.clear(); + + const char* sep = ""; + for (const auto& entry : value_) { + value += sep + std::to_string(entry); + sep = "/"; + } +} + SimpleParametrisation::SimpleParametrisation() = default; @@ -642,6 +674,9 @@ bool SimpleParametrisation::get(const std::string& name, std::vector& va return _get(name, value); } +bool SimpleParametrisation::get(const std::string& name, std::vector& value) const { + return _get(name, value); +} bool SimpleParametrisation::get(const std::string& /*name*/, std::vector& /*value*/) const { NOTIMP; @@ -779,6 +814,10 @@ SimpleParametrisation& SimpleParametrisation::set(const std::string& name, const _set(name, value); return *this; } +SimpleParametrisation& SimpleParametrisation::set(const std::string& name, const std::vector& value) { + _set(name, value); + return *this; +} void SimpleParametrisation::print(std::ostream& out) const { diff --git a/src/mir/param/SimpleParametrisation.h b/src/mir/param/SimpleParametrisation.h index a4548a6a7..98d509a0b 100644 --- a/src/mir/param/SimpleParametrisation.h +++ b/src/mir/param/SimpleParametrisation.h @@ -66,6 +66,7 @@ class SimpleParametrisation : public MIRParametrisation { SimpleParametrisation& set(const std::string& name, const std::vector& value); SimpleParametrisation& set(const std::string& name, const std::vector& value); SimpleParametrisation& set(const std::string& name, const std::vector& value); + SimpleParametrisation& set(const std::string& name, const std::vector& value); SimpleParametrisation& set(const std::string& name, const std::vector& value); SimpleParametrisation& set(const std::string& name, const std::vector& value); SimpleParametrisation& set(const std::string& name, const std::vector& value); @@ -95,6 +96,7 @@ class SimpleParametrisation : public MIRParametrisation { bool get(const std::string& name, std::vector& value) const override; bool get(const std::string& name, std::vector& value) const override; + bool get(const std::string& name, std::vector& value) const override; bool get(const std::string& name, std::vector& value) const override; bool get(const std::string& name, std::vector& value) const override; bool get(const std::string& name, std::vector& value) const override;