Skip to content

Commit

Permalink
fix(ForcingsEngineGriddedDataProvider): handle segfault by setting MP…
Browse files Browse the repository at this point in the history
…I comm; fix timing
  • Loading branch information
program-- committed Sep 9, 2024
1 parent 04f4c9e commit 51bfd05
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 14 deletions.
4 changes: 4 additions & 0 deletions include/bmi/Bmi_Py_Adapter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,10 @@ namespace models {
bmi_model->attr("set_value_at_indices")(name, index_array, src_array);
}

auto model() {
return bmi_model;
}

protected:
std::string model_name = "BMI Python model";

Expand Down
29 changes: 17 additions & 12 deletions src/forcing/ForcingsEngineGriddedDataProvider.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include <chrono>
#include <forcing/ForcingsEngineGriddedDataProvider.hpp>

namespace data_access {
Expand Down Expand Up @@ -102,26 +103,23 @@ std::vector<Provider::data_type> Provider::get_values(
}

const auto duration = std::chrono::seconds{selector.duration};

const auto start = clock_type::from_time_t(selector.init_time);
const auto start = clock_type::from_time_t(selector.init_time);
assert(start >= time_begin_);

const auto end = start + duration;
assert(end <= time_end_);
auto until = (start - time_begin_) + duration;
if (until > time_end_ - time_begin_) {
until = time_end_ - time_begin_;
}

std::vector<double> values;
values.reserve(var_grid_mask_.size());
std::cout << "Starting time: " << start.time_since_epoch().count() << "\n";
for (auto current = start; current < end; current += time_step_, bmi_->UpdateUntil((current - start).count())) {
std::cout << "Current: " << current.time_since_epoch().count() << "\n";
std::cout << "Updated to: " << (current - start).count() << "\n";
values.resize(var_grid_mask_.size());
while (std::chrono::seconds{std::lround(bmi_->GetCurrentTime())} < until) {
// Get a span over the entire grid
boost::span<const double> full = { static_cast<double*>(bmi_->GetValuePtr(variable)), var_grid_.rows * var_grid_.columns };

// Iterate row by row over the grid, masking the grid columns in each row.
// For each row, we add the grid values to the masked grid values.
for (auto r = var_grid_mask_.rmin; r < var_grid_mask_.rmax; ++r) {
std::cout << "At row " << r << "\n";
// Get the starting index of the current row within the full span
// Equation: <starting column offset> + (<row offset> * <row size>)
const std::size_t row_address = var_grid_mask_.cmin + (r * var_grid_.columns);
Expand All @@ -132,15 +130,22 @@ std::vector<Provider::data_type> Provider::get_values(

// Get a span over the current row index on the underlying grid
boost::span<const double> row = full.subspan(row_address, var_grid_mask_.columns());

// Print Row/Column Values
// std::cout << "row " << r << ": ";
// for (auto c = 0; c < var_grid_mask_.columns(); ++c) {
// std::cout << row[c] << ' ';
// }
// std::cout << '\n';

// Get a mutable span over the current row index in the masked values
boost::span<double> masked = { mask_address, var_grid_mask_.columns() };

// Add grid values to masked values
std::transform(row.begin(), row.end(), masked.begin(), masked.begin(), std::plus<double>{});

std::cout << std::endl;
}

bmi_->Update();
}

if (m == ReSampleMethod::MEAN) {
Expand Down
18 changes: 16 additions & 2 deletions test/forcing/ForcingsEngineGriddedDataProvider_Test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ void TestFixture::SetUpTestSuite()
/*time_end_seconds=*/TestFixture::time_end,
/*mask=*/cat_11223_mask
);

#if NGEN_WITH_MPI
auto comm = MPI_Comm_c2f(MPI_COMM_WORLD);
provider_->model()->model()->attr("set_value")("bmi_mpi_comm", py::array_t<MPI_Fint>(comm));
#endif
}

/**
Expand Down Expand Up @@ -100,9 +105,18 @@ TEST_F(ForcingsEngineGriddedDataProviderTest, VariableAccess)

auto selector = GriddedDataSelector{"PSFC", time_start, 3600, "seconds"};
auto result = provider_->get_values(selector, data_access::ReSampleMethod::SUM);
EXPECT_EQ(result.size(), 48980);
// EXPECT_NEAR(result, 99580.52, 1e-2);
EXPECT_EQ(result.size(), provider_->mask().size());

bool at_least_one = false;
for (auto v : result) {
if (v > 0) {
at_least_one = true;
break;
}
}
EXPECT_TRUE(at_least_one) << "All values of `result` are 0";

// EXPECT_NEAR(result, 99580.52, 1e-2);
// selector = CatchmentAggrDataSelector{"cat-11223", "LWDOWN", time_start + 3600, 3600, "seconds"};
// auto result2 = provider_->get_values(selector, data_access::ReSampleMethod::SUM);
// ASSERT_GT(result2.size(), 0);
Expand Down

0 comments on commit 51bfd05

Please sign in to comment.