Skip to content

Commit

Permalink
Merge pull request #5596 from wthrowe/imex/extract_point
Browse files Browse the repository at this point in the history
Add more extract_point overloads and add inverse
  • Loading branch information
nilsdeppe authored Oct 28, 2023
2 parents 85d0d5b + 2015d64 commit cf0b02b
Show file tree
Hide file tree
Showing 10 changed files with 234 additions and 61 deletions.
1 change: 1 addition & 0 deletions src/DataStructures/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ spectre_target_headers(
DynamicBuffer.hpp
DynamicMatrix.hpp
DynamicVector.hpp
ExtractPoint.hpp
FixedHashMap.hpp
FloatingPointType.hpp
IdPair.hpp
Expand Down
112 changes: 112 additions & 0 deletions src/DataStructures/ExtractPoint.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
// Distributed under the MIT License.
// See LICENSE.txt for details.

#pragma once

#include <cstddef>

#include "DataStructures/DataVector.hpp"
#include "DataStructures/Tensor/Tensor.hpp"
#include "DataStructures/Variables.hpp"
#include "Utilities/ErrorHandling/Assert.hpp"
#include "Utilities/Gsl.hpp"
#include "Utilities/TMPL.hpp"

/// \ingroup DataStructuresGroup
/// Copy a given index of each component of a `Tensor<DataVector>` or
/// `Variables<DataVector>` into a `Tensor<double>`, single point
/// `Tensor<DataVector>`, or single-point `Variables<DataVector>`.
///
/// \note There is no by-value overload extracting to a
/// `Tensor<DataVector>`. This is both for the practical reason that
/// it would be ambiguous with the `Tensor<double>` overload and
/// because allocating multiple `DataVector`s for the return type
/// would usually be very inefficient.
///
/// \see overwrite_point
/// @{
template <typename... Structure>
void extract_point(
const gsl::not_null<Tensor<double, Structure...>*> destination,
const Tensor<DataVector, Structure...>& source, const size_t index) {
for (size_t i = 0; i < destination->size(); ++i) {
(*destination)[i] = source[i][index];
}
}

template <typename... Structure>
Tensor<double, Structure...> extract_point(
const Tensor<DataVector, Structure...>& tensor, const size_t index) {
Tensor<double, Structure...> result;
extract_point(make_not_null(&result), tensor, index);
return result;
}

template <typename... Structure>
void extract_point(
const gsl::not_null<Tensor<DataVector, Structure...>*> destination,
const Tensor<DataVector, Structure...>& source, const size_t index) {
ASSERT(destination->begin()->size() == 1,
"Output tensor components have wrong size: "
<< destination->begin()->size());
for (size_t i = 0; i < destination->size(); ++i) {
(*destination)[i][0] = source[i][index];
}
}

template <typename... Tags>
void extract_point(
const gsl::not_null<Variables<tmpl::list<Tags...>>*> result,
const Variables<tmpl::list<Tags...>>& variables, const size_t index) {
result->initialize(1);
expand_pack((extract_point(
make_not_null(&get<Tags>(*result)), get<Tags>(variables), index), 0)...);
}

template <typename... Tags>
Variables<tmpl::list<Tags...>> extract_point(
const Variables<tmpl::list<Tags...>>& variables, const size_t index) {
Variables<tmpl::list<Tags...>> result(1);
extract_point(make_not_null(&result), variables, index);
return result;
}
/// @}

/// \ingroup DataStructuresGroup
/// Copy a `Tensor<double>`, single point `Tensor<DataVector>`, or
/// single-point `Variables<DataVector>` into the given index of each
/// component of a `Tensor<DataVector>` or `Variables<DataVector>`.
///
/// \see extract_point
/// @{
template <typename... Structure>
void overwrite_point(
const gsl::not_null<Tensor<DataVector, Structure...>*> destination,
const Tensor<double, Structure...>& source, const size_t index) {
for (size_t i = 0; i < destination->size(); ++i) {
(*destination)[i][index] = source[i];
}
}

template <typename... Structure>
void overwrite_point(
const gsl::not_null<Tensor<DataVector, Structure...>*> destination,
const Tensor<DataVector, Structure...>& source, const size_t index) {
ASSERT(source.begin()->size() == 1,
"Cannot overwrite with " << source.begin()->size() << " points.");
for (size_t i = 0; i < destination->size(); ++i) {
(*destination)[i][index] = source[i][0];
}
}

template <typename... Tags>
void overwrite_point(
const gsl::not_null<Variables<tmpl::list<Tags...>>*> destination,
const Variables<tmpl::list<Tags...>>& source, const size_t index) {
ASSERT(source.number_of_grid_points() == 1,
"Must overwrite with a single point.");
expand_pack((overwrite_point(make_not_null(&get<Tags>(*destination)),
extract_point(get<Tags>(source), 0), index),
0)...);
}
/// @}
1 change: 0 additions & 1 deletion src/DataStructures/Tensor/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ spectre_target_headers(
${LIBRARY}
INCLUDE_DIRECTORY ${CMAKE_SOURCE_DIR}/src
HEADERS
ExtractPoint.hpp
Identity.hpp
IndexType.hpp
Metafunctions.hpp
Expand Down
22 changes: 0 additions & 22 deletions src/DataStructures/Tensor/ExtractPoint.hpp

This file was deleted.

2 changes: 1 addition & 1 deletion src/Evolution/DgSubcell/SetInterpolators.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
#include <optional>
#include <utility>

#include "DataStructures/ExtractPoint.hpp"
#include "DataStructures/FixedHashMap.hpp"
#include "DataStructures/Tensor/ExtractPoint.hpp"
#include "DataStructures/Tensor/Tensor.hpp"
#include "Domain/Creators/Tags/Domain.hpp"
#include "Domain/Domain.hpp"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
#include <pup.h> // IWYU pragma: keep

#include "DataStructures/DataVector.hpp"
#include "DataStructures/ExtractPoint.hpp"
#include "DataStructures/Tags/TempTensor.hpp"
#include "DataStructures/Tensor/EagerMath/DotProduct.hpp"
#include "DataStructures/Tensor/ExtractPoint.hpp"
#include "DataStructures/Tensor/Tensor.hpp"
#include "DataStructures/Variables.hpp"
#include "NumericalAlgorithms/RootFinding/TOMS748.hpp"
Expand Down
1 change: 1 addition & 0 deletions tests/Unit/DataStructures/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ set(LIBRARY_SOURCES
Test_DataVectorBinaryOperations.cpp
Test_DiagonalModalOperator.cpp
Test_DynamicBuffer.cpp
Test_ExtractPoint.cpp
Test_FixedHashMap.cpp
Test_FloatingPointType.cpp
Test_IdPair.cpp
Expand Down
1 change: 0 additions & 1 deletion tests/Unit/DataStructures/Tensor/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
set(LIBRARY "Test_Tensor")

set(LIBRARY_SOURCES
Test_ExtractPoint.cpp
Test_Identity.cpp
Test_Metafunctions.cpp
Test_Slice.cpp
Expand Down
35 changes: 0 additions & 35 deletions tests/Unit/DataStructures/Tensor/Test_ExtractPoint.cpp

This file was deleted.

118 changes: 118 additions & 0 deletions tests/Unit/DataStructures/Test_ExtractPoint.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
// Distributed under the MIT License.
// See LICENSE.txt for details.

#include "Framework/TestingFramework.hpp"

#include <array>

#include "DataStructures/DataBox/Tag.hpp"
#include "DataStructures/DataVector.hpp"
#include "DataStructures/ExtractPoint.hpp"
#include "DataStructures/Tensor/Tensor.hpp"
#include "Utilities/Literals.hpp"
#include "Utilities/TMPL.hpp"

namespace {
struct ScalarTag : db::SimpleTag {
using type = Scalar<DataVector>;
};

struct TensorTag : db::SimpleTag {
using type = tnsr::ii<DataVector, 2>;
};

template <typename... Structure>
Tensor<DataVector, Structure...> promote_to_data_vectors(
const Tensor<double, Structure...>& tensor) {
Tensor<DataVector, Structure...> result(1_st);
for (size_t i = 0; i < result.size(); ++i) {
result[i] = tensor[i];
}
return result;
}
} // namespace

SPECTRE_TEST_CASE("Unit.DataStructures.ExtractPoint",
"[DataStructures][Unit]") {
// Test with owning Tensors as well as Variables to make sure there
// are no assumptions about memory layout in the tensor versions.
const Scalar<DataVector> scalar{DataVector{3.0, 4.0}};
tnsr::ii<DataVector, 2> tensor;
get<0, 0>(tensor) = DataVector{1.0, 2.0};
get<0, 1>(tensor) = DataVector{3.0, 4.0};
get<1, 1>(tensor) = DataVector{5.0, 6.0};
Variables<tmpl::list<ScalarTag, TensorTag>> variables(2);
get<ScalarTag>(variables) = scalar;
get<TensorTag>(variables) = tensor;

Scalar<DataVector> reconstructed_scalar{DataVector(2)};
tnsr::ii<DataVector, 2> reconstructed_tensor(DataVector(2));
Scalar<DataVector> reconstructed_scalar_from_dv{DataVector(2)};
tnsr::ii<DataVector, 2> reconstructed_tensor_from_dv(DataVector(2));
Variables<tmpl::list<ScalarTag, TensorTag>> reconstructed_variables(2);

const auto check_point = [&](const size_t index,
const double scalar_component,
const std::array<double, 3>& tensor_components) {
const Scalar<double> expected_scalar{scalar_component};
tnsr::ii<double, 2> expected_tensor;
get<0, 0>(expected_tensor) = tensor_components[0];
get<0, 1>(expected_tensor) = tensor_components[1];
get<1, 1>(expected_tensor) = tensor_components[2];
const Scalar<DataVector> expected_scalar_dv =
promote_to_data_vectors(expected_scalar);
const tnsr::ii<DataVector, 2> expected_tensor_dv =
promote_to_data_vectors(expected_tensor);
Variables<tmpl::list<ScalarTag, TensorTag>> expected_variables(1);
get<ScalarTag>(expected_variables) = expected_scalar_dv;
get<TensorTag>(expected_variables) = expected_tensor_dv;

{
Scalar<double> result{};
extract_point(make_not_null(&result), scalar, index);
CHECK(result == expected_scalar);
}
{
tnsr::ii<double, 2> result{};
extract_point(make_not_null(&result), tensor, index);
CHECK(result == expected_tensor);
}
CHECK(extract_point(scalar, index) == expected_scalar);
CHECK(extract_point(tensor, index) == expected_tensor);
{
Scalar<DataVector> result(1_st);
extract_point(make_not_null(&result), scalar, index);
CHECK(result == expected_scalar_dv);
}
{
tnsr::ii<DataVector, 2> result(1_st);
extract_point(make_not_null(&result), tensor, index);
CHECK(result == expected_tensor_dv);
}
{
Variables<tmpl::list<ScalarTag, TensorTag>> result(1);
extract_point(make_not_null(&result), variables, index);
CHECK(result == expected_variables);
}
CHECK(extract_point(variables, index) == expected_variables);

overwrite_point(make_not_null(&reconstructed_scalar), expected_scalar,
index);
overwrite_point(make_not_null(&reconstructed_tensor), expected_tensor,
index);
overwrite_point(make_not_null(&reconstructed_scalar_from_dv),
expected_scalar_dv, index);
overwrite_point(make_not_null(&reconstructed_tensor_from_dv),
expected_tensor_dv, index);
overwrite_point(make_not_null(&reconstructed_variables), expected_variables,
index);
};
check_point(0, 3.0, {{1.0, 3.0, 5.0}});
check_point(1, 4.0, {{2.0, 4.0, 6.0}});

CHECK(reconstructed_scalar == scalar);
CHECK(reconstructed_tensor == tensor);
CHECK(reconstructed_scalar_from_dv == scalar);
CHECK(reconstructed_tensor_from_dv == tensor);
CHECK(reconstructed_variables == variables);
}

0 comments on commit cf0b02b

Please sign in to comment.