diff --git a/include/bmi/Bmi_Py_Adapter.hpp b/include/bmi/Bmi_Py_Adapter.hpp index ba1a823dc2..00175f8fda 100644 --- a/include/bmi/Bmi_Py_Adapter.hpp +++ b/include/bmi/Bmi_Py_Adapter.hpp @@ -623,6 +623,10 @@ namespace models { bmi_model->attr("set_value_at_indices")(name, index_array, src_array); } + auto model() { + return bmi_model; + } + protected: std::string model_name = "BMI Python model"; diff --git a/src/forcing/ForcingsEngineGriddedDataProvider.cpp b/src/forcing/ForcingsEngineGriddedDataProvider.cpp index 032e1f831b..f61eaa53fc 100644 --- a/src/forcing/ForcingsEngineGriddedDataProvider.cpp +++ b/src/forcing/ForcingsEngineGriddedDataProvider.cpp @@ -1,3 +1,4 @@ +#include #include namespace data_access { @@ -102,26 +103,23 @@ std::vector Provider::get_values( } const auto duration = std::chrono::seconds{selector.duration}; - - const auto start = clock_type::from_time_t(selector.init_time); + const auto start = clock_type::from_time_t(selector.init_time); assert(start >= time_begin_); - const auto end = start + duration; - assert(end <= time_end_); + auto until = (start - time_begin_) + duration; + if (until > time_end_ - time_begin_) { + until = time_end_ - time_begin_; + } std::vector values; - values.reserve(var_grid_mask_.size()); - std::cout << "Starting time: " << start.time_since_epoch().count() << "\n"; - for (auto current = start; current < end; current += time_step_, bmi_->UpdateUntil((current - start).count())) { - std::cout << "Current: " << current.time_since_epoch().count() << "\n"; - std::cout << "Updated to: " << (current - start).count() << "\n"; + values.resize(var_grid_mask_.size()); + while (std::chrono::seconds{std::lround(bmi_->GetCurrentTime())} < until) { // Get a span over the entire grid boost::span full = { static_cast(bmi_->GetValuePtr(variable)), var_grid_.rows * var_grid_.columns }; // Iterate row by row over the grid, masking the grid columns in each row. // For each row, we add the grid values to the masked grid values. for (auto r = var_grid_mask_.rmin; r < var_grid_mask_.rmax; ++r) { - std::cout << "At row " << r << "\n"; // Get the starting index of the current row within the full span // Equation: + ( * ) const std::size_t row_address = var_grid_mask_.cmin + (r * var_grid_.columns); @@ -132,15 +130,22 @@ std::vector Provider::get_values( // Get a span over the current row index on the underlying grid boost::span row = full.subspan(row_address, var_grid_mask_.columns()); + + // Print Row/Column Values + // std::cout << "row " << r << ": "; + // for (auto c = 0; c < var_grid_mask_.columns(); ++c) { + // std::cout << row[c] << ' '; + // } + // std::cout << '\n'; // Get a mutable span over the current row index in the masked values boost::span masked = { mask_address, var_grid_mask_.columns() }; // Add grid values to masked values std::transform(row.begin(), row.end(), masked.begin(), masked.begin(), std::plus{}); - - std::cout << std::endl; } + + bmi_->Update(); } if (m == ReSampleMethod::MEAN) { diff --git a/test/forcing/ForcingsEngineGriddedDataProvider_Test.cpp b/test/forcing/ForcingsEngineGriddedDataProvider_Test.cpp index f2f5c9c8b7..6c624a8c87 100644 --- a/test/forcing/ForcingsEngineGriddedDataProvider_Test.cpp +++ b/test/forcing/ForcingsEngineGriddedDataProvider_Test.cpp @@ -53,6 +53,11 @@ void TestFixture::SetUpTestSuite() /*time_end_seconds=*/TestFixture::time_end, /*mask=*/cat_11223_mask ); + + #if NGEN_WITH_MPI + auto comm = MPI_Comm_c2f(MPI_COMM_WORLD); + provider_->model()->model()->attr("set_value")("bmi_mpi_comm", py::array_t(comm)); + #endif } /** @@ -100,9 +105,18 @@ TEST_F(ForcingsEngineGriddedDataProviderTest, VariableAccess) auto selector = GriddedDataSelector{"PSFC", time_start, 3600, "seconds"}; auto result = provider_->get_values(selector, data_access::ReSampleMethod::SUM); - EXPECT_EQ(result.size(), 48980); - // EXPECT_NEAR(result, 99580.52, 1e-2); + EXPECT_EQ(result.size(), provider_->mask().size()); + + bool at_least_one = false; + for (auto v : result) { + if (v > 0) { + at_least_one = true; + break; + } + } + EXPECT_TRUE(at_least_one) << "All values of `result` are 0"; + // EXPECT_NEAR(result, 99580.52, 1e-2); // selector = CatchmentAggrDataSelector{"cat-11223", "LWDOWN", time_start + 3600, 3600, "seconds"}; // auto result2 = provider_->get_values(selector, data_access::ReSampleMethod::SUM); // ASSERT_GT(result2.size(), 0);