-
Notifications
You must be signed in to change notification settings - Fork 192
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #5596 from wthrowe/imex/extract_point
Add more extract_point overloads and add inverse
- Loading branch information
Showing
10 changed files
with
234 additions
and
61 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
// Distributed under the MIT License. | ||
// See LICENSE.txt for details. | ||
|
||
#pragma once | ||
|
||
#include <cstddef> | ||
|
||
#include "DataStructures/DataVector.hpp" | ||
#include "DataStructures/Tensor/Tensor.hpp" | ||
#include "DataStructures/Variables.hpp" | ||
#include "Utilities/ErrorHandling/Assert.hpp" | ||
#include "Utilities/Gsl.hpp" | ||
#include "Utilities/TMPL.hpp" | ||
|
||
/// \ingroup DataStructuresGroup | ||
/// Copy a given index of each component of a `Tensor<DataVector>` or | ||
/// `Variables<DataVector>` into a `Tensor<double>`, single point | ||
/// `Tensor<DataVector>`, or single-point `Variables<DataVector>`. | ||
/// | ||
/// \note There is no by-value overload extracting to a | ||
/// `Tensor<DataVector>`. This is both for the practical reason that | ||
/// it would be ambiguous with the `Tensor<double>` overload and | ||
/// because allocating multiple `DataVector`s for the return type | ||
/// would usually be very inefficient. | ||
/// | ||
/// \see overwrite_point | ||
/// @{ | ||
template <typename... Structure> | ||
void extract_point( | ||
const gsl::not_null<Tensor<double, Structure...>*> destination, | ||
const Tensor<DataVector, Structure...>& source, const size_t index) { | ||
for (size_t i = 0; i < destination->size(); ++i) { | ||
(*destination)[i] = source[i][index]; | ||
} | ||
} | ||
|
||
template <typename... Structure> | ||
Tensor<double, Structure...> extract_point( | ||
const Tensor<DataVector, Structure...>& tensor, const size_t index) { | ||
Tensor<double, Structure...> result; | ||
extract_point(make_not_null(&result), tensor, index); | ||
return result; | ||
} | ||
|
||
template <typename... Structure> | ||
void extract_point( | ||
const gsl::not_null<Tensor<DataVector, Structure...>*> destination, | ||
const Tensor<DataVector, Structure...>& source, const size_t index) { | ||
ASSERT(destination->begin()->size() == 1, | ||
"Output tensor components have wrong size: " | ||
<< destination->begin()->size()); | ||
for (size_t i = 0; i < destination->size(); ++i) { | ||
(*destination)[i][0] = source[i][index]; | ||
} | ||
} | ||
|
||
template <typename... Tags> | ||
void extract_point( | ||
const gsl::not_null<Variables<tmpl::list<Tags...>>*> result, | ||
const Variables<tmpl::list<Tags...>>& variables, const size_t index) { | ||
result->initialize(1); | ||
expand_pack((extract_point( | ||
make_not_null(&get<Tags>(*result)), get<Tags>(variables), index), 0)...); | ||
} | ||
|
||
template <typename... Tags> | ||
Variables<tmpl::list<Tags...>> extract_point( | ||
const Variables<tmpl::list<Tags...>>& variables, const size_t index) { | ||
Variables<tmpl::list<Tags...>> result(1); | ||
extract_point(make_not_null(&result), variables, index); | ||
return result; | ||
} | ||
/// @} | ||
|
||
/// \ingroup DataStructuresGroup | ||
/// Copy a `Tensor<double>`, single point `Tensor<DataVector>`, or | ||
/// single-point `Variables<DataVector>` into the given index of each | ||
/// component of a `Tensor<DataVector>` or `Variables<DataVector>`. | ||
/// | ||
/// \see extract_point | ||
/// @{ | ||
template <typename... Structure> | ||
void overwrite_point( | ||
const gsl::not_null<Tensor<DataVector, Structure...>*> destination, | ||
const Tensor<double, Structure...>& source, const size_t index) { | ||
for (size_t i = 0; i < destination->size(); ++i) { | ||
(*destination)[i][index] = source[i]; | ||
} | ||
} | ||
|
||
template <typename... Structure> | ||
void overwrite_point( | ||
const gsl::not_null<Tensor<DataVector, Structure...>*> destination, | ||
const Tensor<DataVector, Structure...>& source, const size_t index) { | ||
ASSERT(source.begin()->size() == 1, | ||
"Cannot overwrite with " << source.begin()->size() << " points."); | ||
for (size_t i = 0; i < destination->size(); ++i) { | ||
(*destination)[i][index] = source[i][0]; | ||
} | ||
} | ||
|
||
template <typename... Tags> | ||
void overwrite_point( | ||
const gsl::not_null<Variables<tmpl::list<Tags...>>*> destination, | ||
const Variables<tmpl::list<Tags...>>& source, const size_t index) { | ||
ASSERT(source.number_of_grid_points() == 1, | ||
"Must overwrite with a single point."); | ||
expand_pack((overwrite_point(make_not_null(&get<Tags>(*destination)), | ||
extract_point(get<Tags>(source), 0), index), | ||
0)...); | ||
} | ||
/// @} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
// Distributed under the MIT License. | ||
// See LICENSE.txt for details. | ||
|
||
#include "Framework/TestingFramework.hpp" | ||
|
||
#include <array> | ||
|
||
#include "DataStructures/DataBox/Tag.hpp" | ||
#include "DataStructures/DataVector.hpp" | ||
#include "DataStructures/ExtractPoint.hpp" | ||
#include "DataStructures/Tensor/Tensor.hpp" | ||
#include "Utilities/Literals.hpp" | ||
#include "Utilities/TMPL.hpp" | ||
|
||
namespace { | ||
struct ScalarTag : db::SimpleTag { | ||
using type = Scalar<DataVector>; | ||
}; | ||
|
||
struct TensorTag : db::SimpleTag { | ||
using type = tnsr::ii<DataVector, 2>; | ||
}; | ||
|
||
template <typename... Structure> | ||
Tensor<DataVector, Structure...> promote_to_data_vectors( | ||
const Tensor<double, Structure...>& tensor) { | ||
Tensor<DataVector, Structure...> result(1_st); | ||
for (size_t i = 0; i < result.size(); ++i) { | ||
result[i] = tensor[i]; | ||
} | ||
return result; | ||
} | ||
} // namespace | ||
|
||
SPECTRE_TEST_CASE("Unit.DataStructures.ExtractPoint", | ||
"[DataStructures][Unit]") { | ||
// Test with owning Tensors as well as Variables to make sure there | ||
// are no assumptions about memory layout in the tensor versions. | ||
const Scalar<DataVector> scalar{DataVector{3.0, 4.0}}; | ||
tnsr::ii<DataVector, 2> tensor; | ||
get<0, 0>(tensor) = DataVector{1.0, 2.0}; | ||
get<0, 1>(tensor) = DataVector{3.0, 4.0}; | ||
get<1, 1>(tensor) = DataVector{5.0, 6.0}; | ||
Variables<tmpl::list<ScalarTag, TensorTag>> variables(2); | ||
get<ScalarTag>(variables) = scalar; | ||
get<TensorTag>(variables) = tensor; | ||
|
||
Scalar<DataVector> reconstructed_scalar{DataVector(2)}; | ||
tnsr::ii<DataVector, 2> reconstructed_tensor(DataVector(2)); | ||
Scalar<DataVector> reconstructed_scalar_from_dv{DataVector(2)}; | ||
tnsr::ii<DataVector, 2> reconstructed_tensor_from_dv(DataVector(2)); | ||
Variables<tmpl::list<ScalarTag, TensorTag>> reconstructed_variables(2); | ||
|
||
const auto check_point = [&](const size_t index, | ||
const double scalar_component, | ||
const std::array<double, 3>& tensor_components) { | ||
const Scalar<double> expected_scalar{scalar_component}; | ||
tnsr::ii<double, 2> expected_tensor; | ||
get<0, 0>(expected_tensor) = tensor_components[0]; | ||
get<0, 1>(expected_tensor) = tensor_components[1]; | ||
get<1, 1>(expected_tensor) = tensor_components[2]; | ||
const Scalar<DataVector> expected_scalar_dv = | ||
promote_to_data_vectors(expected_scalar); | ||
const tnsr::ii<DataVector, 2> expected_tensor_dv = | ||
promote_to_data_vectors(expected_tensor); | ||
Variables<tmpl::list<ScalarTag, TensorTag>> expected_variables(1); | ||
get<ScalarTag>(expected_variables) = expected_scalar_dv; | ||
get<TensorTag>(expected_variables) = expected_tensor_dv; | ||
|
||
{ | ||
Scalar<double> result{}; | ||
extract_point(make_not_null(&result), scalar, index); | ||
CHECK(result == expected_scalar); | ||
} | ||
{ | ||
tnsr::ii<double, 2> result{}; | ||
extract_point(make_not_null(&result), tensor, index); | ||
CHECK(result == expected_tensor); | ||
} | ||
CHECK(extract_point(scalar, index) == expected_scalar); | ||
CHECK(extract_point(tensor, index) == expected_tensor); | ||
{ | ||
Scalar<DataVector> result(1_st); | ||
extract_point(make_not_null(&result), scalar, index); | ||
CHECK(result == expected_scalar_dv); | ||
} | ||
{ | ||
tnsr::ii<DataVector, 2> result(1_st); | ||
extract_point(make_not_null(&result), tensor, index); | ||
CHECK(result == expected_tensor_dv); | ||
} | ||
{ | ||
Variables<tmpl::list<ScalarTag, TensorTag>> result(1); | ||
extract_point(make_not_null(&result), variables, index); | ||
CHECK(result == expected_variables); | ||
} | ||
CHECK(extract_point(variables, index) == expected_variables); | ||
|
||
overwrite_point(make_not_null(&reconstructed_scalar), expected_scalar, | ||
index); | ||
overwrite_point(make_not_null(&reconstructed_tensor), expected_tensor, | ||
index); | ||
overwrite_point(make_not_null(&reconstructed_scalar_from_dv), | ||
expected_scalar_dv, index); | ||
overwrite_point(make_not_null(&reconstructed_tensor_from_dv), | ||
expected_tensor_dv, index); | ||
overwrite_point(make_not_null(&reconstructed_variables), expected_variables, | ||
index); | ||
}; | ||
check_point(0, 3.0, {{1.0, 3.0, 5.0}}); | ||
check_point(1, 4.0, {{2.0, 4.0, 6.0}}); | ||
|
||
CHECK(reconstructed_scalar == scalar); | ||
CHECK(reconstructed_tensor == tensor); | ||
CHECK(reconstructed_scalar_from_dv == scalar); | ||
CHECK(reconstructed_tensor_from_dv == tensor); | ||
CHECK(reconstructed_variables == variables); | ||
} |