Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/bitmask missing values support. #22

Open
wants to merge 6 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 94 additions & 3 deletions src/mir/method/MethodWeighted.cc
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,18 @@ void MethodWeighted::createMatrix(context::Context& ctx, const repres::Represent
applyMasks(W, masks);
W.validate("applyMasks", checks);
}

bool bitmask;
parametrisation_.get("imm", bitmask);
if (bitmask){
std::vector<bool> vec_bitmask;
parametrisation_.get("imm-mask", vec_bitmask);
applyIMM(W,vec_bitmask);
if (matrixValidate_) {
W.validate("applyMasks");
}
}

}


Expand Down Expand Up @@ -208,6 +220,14 @@ MethodWeighted::CacheKeys MethodWeighted::getDiskAndMemoryCacheKeys(const repres
}
}

bool bitmask;
parametrisation_.get("imm", bitmask);
if (bitmask){
std::string missing_mask_key;
parametrisation_.get("imm-name", missing_mask_key);
memory_key += "_" + missing_mask_key;
disk_key += "_" + missing_mask_key;
}
return {disk_key, memory_key};
}

Expand Down Expand Up @@ -450,12 +470,16 @@ void MethodWeighted::execute(context::Context& ctx, const repres::Representation
const double missingValue = field.missingValue();

// matrix copy: run-time modifiable matrix is not cacheable
const bool matrixCopy = std::any_of(nonLinear_.begin(), nonLinear_.end(),
bool imm_active;
parametrisation_.get("imm", imm_active);
bool matrixCopy = false;
if (!imm_active) {
matrixCopy = std::any_of(nonLinear_.begin(), nonLinear_.end(),
[&field](const std::unique_ptr<const nonlinear::NonLinear>& n) {
return n->modifiesMatrix(field.hasMissing());
});


}
for (size_t i = 0; i < field.dimensions(); i++) {

std::ostringstream os;
Expand Down Expand Up @@ -621,6 +645,73 @@ void MethodWeighted::applyMasks(WeightMatrix& W, const lsm::LandSeaMasks& masks)
<< Log::Pretty(W.rows(), {"output point"}) << std::endl;
}

void MethodWeighted::applyIMM(WeightMatrix& W, const std::vector<bool>& imask) const {
trace::Timer timer("MethodWeighted::applyIMM");
auto& log = Log::debug();

log << "MethodWeighted::applyIMM" << std::endl;
log << "mask size=" << imask.size() << " == #cols=" << W.cols() << std::endl;

// Ensure that the mask size matches the number of columns in W
ASSERT(imask.size() == W.cols());

auto* data = const_cast<WeightMatrix::Scalar*>(W.data());
auto* outer = W.outer();
auto* inner = W.inner();
size_t fix = 0;

// Iterate over each row
for (WeightMatrix::Size r = 0; r < W.rows(); ++r) {
WeightMatrix::Size row_start = outer[r];
WeightMatrix::Size row_end = outer[r + 1];

size_t i_missing = row_start;
size_t N_missing = 0;
size_t N_entries = row_end - row_start;
double sum = 0.0;
double heaviest = -1.0;
bool heaviest_is_missing = false;

// Iterate over the entries in the current row
for (WeightMatrix::Size i = row_start; i < row_end; ++i) {
const bool miss = !imask[inner[i]];

if (miss) {
++N_missing;
i_missing = i;
} else {
sum += data[i];
}

if (heaviest < data[i]) {
heaviest = data[i];
heaviest_is_missing = miss;
}
}

// Weights redistribution: zero-weight all missing values, linear re-weighting for the others
if (N_missing > 0) {
++fix;
if (N_missing == N_entries || heaviest_is_missing || eckit::types::is_approximately_equal(sum, 0.0)) {
// All values are missing or the heaviest is missing; set only i_missing to 1
for (WeightMatrix::Size i = row_start; i < row_end; ++i) {
data[i] = (i == i_missing) ? 1.0 : 0.0;
}
} else {
// Scale non-missing entries so they sum to 1
const double factor = 1.0 / sum;
for (WeightMatrix::Size i = row_start; i < row_end; ++i) {
const bool miss = !imask[inner[i]];
data[i] = miss ? 0.0 : (factor * data[i]);
}
}
}
}

// Log the number of corrections made
log << "MethodWeighted: applyIMM corrected " << Log::Pretty(fix) << " out of "
<< Log::Pretty(W.rows(), {"Weight matrix rows"}) << std::endl;
}

void MethodWeighted::hash(eckit::MD5& md5) const {
md5.add(name());
Expand Down
1 change: 1 addition & 0 deletions src/mir/method/MethodWeighted.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ class MethodWeighted : public Method {
const lsm::LandSeaMasks&) const;

virtual void applyMasks(WeightMatrix&, const lsm::LandSeaMasks&) const;
virtual void applyIMM(WeightMatrix&, const std::vector<bool>&) const;
virtual lsm::LandSeaMasks getMasks(const repres::Representation& in, const repres::Representation& out) const;
virtual WeightMatrix::Check validateMatrixWeights() const;

Expand Down
4 changes: 4 additions & 0 deletions src/mir/param/CombinedParametrisation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ bool CombinedParametrisation::get(const std::string& name, std::vector<double>&
return _get(name, value);
}

bool CombinedParametrisation::get(const std::string& name, std::vector<bool>& value) const {
return _get(name, value);
}

bool CombinedParametrisation::get(const std::string& name, std::vector<std::string>& value) const {
return _get(name, value);
}
Expand Down
1 change: 1 addition & 0 deletions src/mir/param/CombinedParametrisation.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ class CombinedParametrisation : public MIRParametrisation {
bool get(const std::string& name, std::vector<long>& value) const override;
bool get(const std::string& name, std::vector<float>& value) const override;
bool get(const std::string& name, std::vector<double>& value) const override;
bool get(const std::string& name, std::vector<bool>& value) const override;
bool get(const std::string& name, std::vector<std::string>& value) const override;

// -- Class members
Expand Down
4 changes: 4 additions & 0 deletions src/mir/param/DefaultParametrisation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ DefaultParametrisation::DefaultParametrisation() {
set("lsm-weight-adjustment", 0.2);
set("lsm-value-threshold", 0.5);

set("imm", false);
set("imm-name", "bitmasked");
set("imm-mask", std::vector<bool>{});

set("spectral-order", "linear");

set("compare", "scalar");
Expand Down
5 changes: 5 additions & 0 deletions src/mir/param/MIRParametrisation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,4 +89,9 @@ bool MIRParametrisation::get(const std::string& name, std::vector<long long>& va
return false;
}

bool MIRParametrisation::get(const std::string& name, std::vector<bool>& value) const {
std::ostringstream os;
os << "MIRParametrisation::get(const std::string& name, std::vector<bool>& value) not implemented for " << *this;
throw exception::SeriousBug(os.str());
}
} // namespace mir::param
1 change: 1 addition & 0 deletions src/mir/param/MIRParametrisation.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class MIRParametrisation : public eckit::Parametrisation {
bool get(const std::string& name, std::vector<float>& value) const override = 0;
bool get(const std::string& name, std::vector<double>& value) const override = 0;
bool get(const std::string& name, std::vector<std::string>& value) const override = 0;
virtual bool get(const std::string& name, std::vector<bool>& value) const ;

bool get(const std::string& name, size_t& value) const override;
bool get(const std::string& name, std::vector<size_t>& value) const override;
Expand Down
39 changes: 39 additions & 0 deletions src/mir/param/SimpleParametrisation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class Setting {
virtual void get(const std::string& name, std::vector<float>& value) const = 0;
virtual void get(const std::string& name, std::vector<double>& value) const = 0;
virtual void get(const std::string& name, std::vector<std::string>& value) const = 0;
virtual void get(const std::string& name, std::vector<bool>& value) const = 0;

virtual bool matchAll(const std::string& name, const MIRParametrisation&) const = 0;
virtual bool matchAny(const std::string& name, const MIRParametrisation&) const = 0;
Expand Down Expand Up @@ -158,6 +159,12 @@ const char* TNamed<std::vector<std::string>>() {
return "vector<string>";
}

template <>
const char* TNamed<std::vector<bool>>() {
return "vector<bool>";
}



template <class T>
static void conversion_warning(const char* /*from*/, const char* /*to*/, const std::string& /*name*/,
Expand Down Expand Up @@ -226,6 +233,10 @@ class TSettings : public Setting {
throw exception::CannotConvert(TNamed<T>(), "vector<string>", name, value_);
}

void get(const std::string& name, std::vector<bool>& /*value*/) const {
throw exception::CannotConvert(TNamed<T>(), "vector<bool>", name, value_);
}

bool matchAll(const std::string& name, const MIRParametrisation& other) const override {
T value;
return other.get(name, value) && value_ == value;
Expand Down Expand Up @@ -269,6 +280,11 @@ void TSettings<std::vector<double>>::print(std::ostream& out) const {
_put(out, value_);
}

template <>
void TSettings<std::vector<bool>>::print(std::ostream& out) const {
_put(out, value_); // Uses the `_put` helper function for formatting
}


template <>
bool TSettings<std::vector<long>>::matchAll(const std::string& name, const MIRParametrisation& other) const {
Expand Down Expand Up @@ -443,6 +459,10 @@ void TSettings<std::vector<double>>::get(const std::string& /*name*/, std::vecto
value = value_;
}

template <>
void TSettings<std::vector<bool>>::get(const std::string& /*name*/, std::vector<bool>& value) const {
value = value_;
}

template <>
void TSettings<int>::get(const std::string& name, std::string& value) const {
Expand Down Expand Up @@ -555,6 +575,18 @@ void TSettings<std::vector<std::string>>::get(const std::string& name, std::stri
}
}

template <>
void TSettings<std::vector<bool>>::get(const std::string& name, std::string& value) const {
conversion_warning("vector<bool>", "string", name, value_);
value.clear();

const char* sep = "";
for (const auto& entry : value_) {
value += sep + std::to_string(entry);
sep = "/";
}
}


SimpleParametrisation::SimpleParametrisation() = default;

Expand Down Expand Up @@ -642,6 +674,9 @@ bool SimpleParametrisation::get(const std::string& name, std::vector<double>& va
return _get(name, value);
}

bool SimpleParametrisation::get(const std::string& name, std::vector<bool>& value) const {
return _get(name, value);
}

bool SimpleParametrisation::get(const std::string& /*name*/, std::vector<std::string>& /*value*/) const {
NOTIMP;
Expand Down Expand Up @@ -779,6 +814,10 @@ SimpleParametrisation& SimpleParametrisation::set(const std::string& name, const
_set(name, value);
return *this;
}
SimpleParametrisation& SimpleParametrisation::set(const std::string& name, const std::vector<bool>& value) {
_set(name, value);
return *this;
}


void SimpleParametrisation::print(std::ostream& out) const {
Expand Down
2 changes: 2 additions & 0 deletions src/mir/param/SimpleParametrisation.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class SimpleParametrisation : public MIRParametrisation {
SimpleParametrisation& set(const std::string& name, const std::vector<int>& value);
SimpleParametrisation& set(const std::string& name, const std::vector<long>& value);
SimpleParametrisation& set(const std::string& name, const std::vector<long long>& value);
SimpleParametrisation& set(const std::string& name, const std::vector<bool>& value);
SimpleParametrisation& set(const std::string& name, const std::vector<size_t>& value);
SimpleParametrisation& set(const std::string& name, const std::vector<float>& value);
SimpleParametrisation& set(const std::string& name, const std::vector<double>& value);
Expand Down Expand Up @@ -95,6 +96,7 @@ class SimpleParametrisation : public MIRParametrisation {

bool get(const std::string& name, std::vector<int>& value) const override;
bool get(const std::string& name, std::vector<long>& value) const override;
bool get(const std::string& name, std::vector<bool>& value) const override;
bool get(const std::string& name, std::vector<float>& value) const override;
bool get(const std::string& name, std::vector<double>& value) const override;
bool get(const std::string& name, std::vector<std::string>& value) const override;
Expand Down