Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add more extract_point overloads and add inverse #5596

Merged
merged 1 commit into from
Oct 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/DataStructures/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ spectre_target_headers(
DynamicBuffer.hpp
DynamicMatrix.hpp
DynamicVector.hpp
ExtractPoint.hpp
FixedHashMap.hpp
FloatingPointType.hpp
IdPair.hpp
Expand Down
112 changes: 112 additions & 0 deletions src/DataStructures/ExtractPoint.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
// Distributed under the MIT License.
// See LICENSE.txt for details.

#pragma once

#include <cstddef>

#include "DataStructures/DataVector.hpp"
#include "DataStructures/Tensor/Tensor.hpp"
#include "DataStructures/Variables.hpp"
#include "Utilities/ErrorHandling/Assert.hpp"
#include "Utilities/Gsl.hpp"
#include "Utilities/TMPL.hpp"

/// \ingroup DataStructuresGroup
/// Copy a given index of each component of a `Tensor<DataVector>` or
/// `Variables<DataVector>` into a `Tensor<double>`, single point
/// `Tensor<DataVector>`, or single-point `Variables<DataVector>`.
///
/// \note There is no by-value overload extracting to a
/// `Tensor<DataVector>`. This is both for the practical reason that
/// it would be ambiguous with the `Tensor<double>` overload and
/// because allocating multiple `DataVector`s for the return type
/// would usually be very inefficient.
///
/// \see overwrite_point
/// @{
template <typename... Structure>
void extract_point(
const gsl::not_null<Tensor<double, Structure...>*> destination,
const Tensor<DataVector, Structure...>& source, const size_t index) {
for (size_t i = 0; i < destination->size(); ++i) {
(*destination)[i] = source[i][index];
}
}

template <typename... Structure>
Tensor<double, Structure...> extract_point(
const Tensor<DataVector, Structure...>& tensor, const size_t index) {
Tensor<double, Structure...> result;
extract_point(make_not_null(&result), tensor, index);
return result;
}

template <typename... Structure>
void extract_point(
const gsl::not_null<Tensor<DataVector, Structure...>*> destination,
const Tensor<DataVector, Structure...>& source, const size_t index) {
ASSERT(destination->begin()->size() == 1,
"Output tensor components have wrong size: "
<< destination->begin()->size());
for (size_t i = 0; i < destination->size(); ++i) {
(*destination)[i][0] = source[i][index];
}
}

template <typename... Tags>
void extract_point(
const gsl::not_null<Variables<tmpl::list<Tags...>>*> result,
const Variables<tmpl::list<Tags...>>& variables, const size_t index) {
result->initialize(1);
expand_pack((extract_point(
make_not_null(&get<Tags>(*result)), get<Tags>(variables), index), 0)...);
}

template <typename... Tags>
Variables<tmpl::list<Tags...>> extract_point(
const Variables<tmpl::list<Tags...>>& variables, const size_t index) {
Variables<tmpl::list<Tags...>> result(1);
extract_point(make_not_null(&result), variables, index);
return result;
}
/// @}

/// \ingroup DataStructuresGroup
/// Copy a `Tensor<double>`, single point `Tensor<DataVector>`, or
/// single-point `Variables<DataVector>` into the given index of each
/// component of a `Tensor<DataVector>` or `Variables<DataVector>`.
///
/// \see extract_point
/// @{
template <typename... Structure>
void overwrite_point(
const gsl::not_null<Tensor<DataVector, Structure...>*> destination,
const Tensor<double, Structure...>& source, const size_t index) {
for (size_t i = 0; i < destination->size(); ++i) {
(*destination)[i][index] = source[i];
}
}

template <typename... Structure>
void overwrite_point(
const gsl::not_null<Tensor<DataVector, Structure...>*> destination,
const Tensor<DataVector, Structure...>& source, const size_t index) {
ASSERT(source.begin()->size() == 1,
"Cannot overwrite with " << source.begin()->size() << " points.");
for (size_t i = 0; i < destination->size(); ++i) {
(*destination)[i][index] = source[i][0];
}
}

template <typename... Tags>
void overwrite_point(
const gsl::not_null<Variables<tmpl::list<Tags...>>*> destination,
const Variables<tmpl::list<Tags...>>& source, const size_t index) {
ASSERT(source.number_of_grid_points() == 1,
"Must overwrite with a single point.");
expand_pack((overwrite_point(make_not_null(&get<Tags>(*destination)),
extract_point(get<Tags>(source), 0), index),
0)...);
}
/// @}
1 change: 0 additions & 1 deletion src/DataStructures/Tensor/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ spectre_target_headers(
${LIBRARY}
INCLUDE_DIRECTORY ${CMAKE_SOURCE_DIR}/src
HEADERS
ExtractPoint.hpp
Identity.hpp
IndexType.hpp
Metafunctions.hpp
Expand Down
22 changes: 0 additions & 22 deletions src/DataStructures/Tensor/ExtractPoint.hpp

This file was deleted.

2 changes: 1 addition & 1 deletion src/Evolution/DgSubcell/SetInterpolators.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
#include <optional>
#include <utility>

#include "DataStructures/ExtractPoint.hpp"
#include "DataStructures/FixedHashMap.hpp"
#include "DataStructures/Tensor/ExtractPoint.hpp"
#include "DataStructures/Tensor/Tensor.hpp"
#include "Domain/Creators/Tags/Domain.hpp"
#include "Domain/Domain.hpp"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
#include <pup.h> // IWYU pragma: keep

#include "DataStructures/DataVector.hpp"
#include "DataStructures/ExtractPoint.hpp"
#include "DataStructures/Tags/TempTensor.hpp"
#include "DataStructures/Tensor/EagerMath/DotProduct.hpp"
#include "DataStructures/Tensor/ExtractPoint.hpp"
#include "DataStructures/Tensor/Tensor.hpp"
#include "DataStructures/Variables.hpp"
#include "NumericalAlgorithms/RootFinding/TOMS748.hpp"
Expand Down
1 change: 1 addition & 0 deletions tests/Unit/DataStructures/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ set(LIBRARY_SOURCES
Test_DataVectorBinaryOperations.cpp
Test_DiagonalModalOperator.cpp
Test_DynamicBuffer.cpp
Test_ExtractPoint.cpp
Test_FixedHashMap.cpp
Test_FloatingPointType.cpp
Test_IdPair.cpp
Expand Down
1 change: 0 additions & 1 deletion tests/Unit/DataStructures/Tensor/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
set(LIBRARY "Test_Tensor")

set(LIBRARY_SOURCES
Test_ExtractPoint.cpp
Test_Identity.cpp
Test_Slice.cpp
Test_Tensor.cpp
Expand Down
35 changes: 0 additions & 35 deletions tests/Unit/DataStructures/Tensor/Test_ExtractPoint.cpp

This file was deleted.

118 changes: 118 additions & 0 deletions tests/Unit/DataStructures/Test_ExtractPoint.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
// Distributed under the MIT License.
// See LICENSE.txt for details.

#include "Framework/TestingFramework.hpp"

#include <array>

#include "DataStructures/DataBox/Tag.hpp"
#include "DataStructures/DataVector.hpp"
#include "DataStructures/ExtractPoint.hpp"
#include "DataStructures/Tensor/Tensor.hpp"
#include "Utilities/Literals.hpp"
#include "Utilities/TMPL.hpp"

namespace {
struct ScalarTag : db::SimpleTag {
using type = Scalar<DataVector>;
};

struct TensorTag : db::SimpleTag {
using type = tnsr::ii<DataVector, 2>;
};

template <typename... Structure>
Tensor<DataVector, Structure...> promote_to_data_vectors(
const Tensor<double, Structure...>& tensor) {
Tensor<DataVector, Structure...> result(1_st);
for (size_t i = 0; i < result.size(); ++i) {
result[i] = tensor[i];
}
return result;
}
} // namespace

SPECTRE_TEST_CASE("Unit.DataStructures.ExtractPoint",
"[DataStructures][Unit]") {
// Test with owning Tensors as well as Variables to make sure there
// are no assumptions about memory layout in the tensor versions.
const Scalar<DataVector> scalar{DataVector{3.0, 4.0}};
tnsr::ii<DataVector, 2> tensor;
get<0, 0>(tensor) = DataVector{1.0, 2.0};
get<0, 1>(tensor) = DataVector{3.0, 4.0};
get<1, 1>(tensor) = DataVector{5.0, 6.0};
Variables<tmpl::list<ScalarTag, TensorTag>> variables(2);
get<ScalarTag>(variables) = scalar;
get<TensorTag>(variables) = tensor;

Scalar<DataVector> reconstructed_scalar{DataVector(2)};
tnsr::ii<DataVector, 2> reconstructed_tensor(DataVector(2));
Scalar<DataVector> reconstructed_scalar_from_dv{DataVector(2)};
tnsr::ii<DataVector, 2> reconstructed_tensor_from_dv(DataVector(2));
Variables<tmpl::list<ScalarTag, TensorTag>> reconstructed_variables(2);

const auto check_point = [&](const size_t index,
const double scalar_component,
const std::array<double, 3>& tensor_components) {
const Scalar<double> expected_scalar{scalar_component};
tnsr::ii<double, 2> expected_tensor;
get<0, 0>(expected_tensor) = tensor_components[0];
get<0, 1>(expected_tensor) = tensor_components[1];
get<1, 1>(expected_tensor) = tensor_components[2];
const Scalar<DataVector> expected_scalar_dv =
promote_to_data_vectors(expected_scalar);
const tnsr::ii<DataVector, 2> expected_tensor_dv =
promote_to_data_vectors(expected_tensor);
Variables<tmpl::list<ScalarTag, TensorTag>> expected_variables(1);
get<ScalarTag>(expected_variables) = expected_scalar_dv;
get<TensorTag>(expected_variables) = expected_tensor_dv;

{
Scalar<double> result{};
extract_point(make_not_null(&result), scalar, index);
CHECK(result == expected_scalar);
}
{
tnsr::ii<double, 2> result{};
extract_point(make_not_null(&result), tensor, index);
CHECK(result == expected_tensor);
}
CHECK(extract_point(scalar, index) == expected_scalar);
CHECK(extract_point(tensor, index) == expected_tensor);
{
Scalar<DataVector> result(1_st);
extract_point(make_not_null(&result), scalar, index);
CHECK(result == expected_scalar_dv);
}
{
tnsr::ii<DataVector, 2> result(1_st);
extract_point(make_not_null(&result), tensor, index);
CHECK(result == expected_tensor_dv);
}
{
Variables<tmpl::list<ScalarTag, TensorTag>> result(1);
extract_point(make_not_null(&result), variables, index);
CHECK(result == expected_variables);
}
CHECK(extract_point(variables, index) == expected_variables);

overwrite_point(make_not_null(&reconstructed_scalar), expected_scalar,
index);
overwrite_point(make_not_null(&reconstructed_tensor), expected_tensor,
index);
overwrite_point(make_not_null(&reconstructed_scalar_from_dv),
expected_scalar_dv, index);
overwrite_point(make_not_null(&reconstructed_tensor_from_dv),
expected_tensor_dv, index);
overwrite_point(make_not_null(&reconstructed_variables), expected_variables,
index);
};
check_point(0, 3.0, {{1.0, 3.0, 5.0}});
check_point(1, 4.0, {{2.0, 4.0, 6.0}});

CHECK(reconstructed_scalar == scalar);
CHECK(reconstructed_tensor == tensor);
CHECK(reconstructed_scalar_from_dv == scalar);
CHECK(reconstructed_tensor_from_dv == tensor);
CHECK(reconstructed_variables == variables);
}
Loading