-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add GridNodeIndex and GridData for MPM (#22295)
- Loading branch information
1 parent
fe6caf8
commit b95c19c
Showing
3 changed files
with
300 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
#pragma once | ||
|
||
#include <cstdint> | ||
|
||
#include "drake/common/bit_cast.h" | ||
#include "drake/common/drake_assert.h" | ||
#include "drake/common/drake_copyable.h" | ||
#include "drake/common/eigen_types.h" | ||
|
||
namespace drake { | ||
namespace multibody { | ||
namespace mpm { | ||
namespace internal { | ||
|
||
/* This class is a lightweight wrapper around an integer type (`int32_t` or | ||
`int64_t`) used to differentiate between active grid node indices, inactive | ||
states, and a special flag. An IndexOrFlag can be in exactly one of the | ||
following states: | ||
1. Active index: A non-negative integer representing the index of a grid node. | ||
2. Inactive state (the default state): neither an index nor a flag. | ||
3. The `flag` state: A special state used to mark grid nodes for deferred | ||
processing. The flag state is intended as a marker that must only be set | ||
from the inactive state (not from an active index) to maintain consistency. | ||
Transitions between states are as follows: | ||
- Any state can become inactive. | ||
- The inactive state can become flag state. | ||
- The inactive state can become active (with a non-negative index). | ||
- Active cannot directly become flag state (must go inactive first). | ||
An IndexOrFlag object is guaranteed to have size equal to its template | ||
parameter T. In addition, the inactive state is guaranteed to be represented | ||
with the value -1. | ||
@tparam T The integer type for the index. Must be `int32_t` or `int64_t`. */ | ||
template <typename T> | ||
class IndexOrFlag { | ||
public: | ||
DRAKE_DEFAULT_COPY_AND_MOVE_AND_ASSIGN(IndexOrFlag); | ||
|
||
static_assert(std::is_same_v<T, int32_t> || std::is_same_v<T, int64_t>, | ||
"T must be int32_t or int64_t."); | ||
|
||
/* Default constructor initializes the object to the inactive state. */ | ||
constexpr IndexOrFlag() = default; | ||
|
||
/* Constructor for an active index. | ||
@pre index >= 0 */ | ||
explicit constexpr IndexOrFlag(T index) { set_index(index); } | ||
|
||
/* Sets the index to the given value, which must be non-negative, thereby | ||
making `this` active. | ||
@pre index >= 0 */ | ||
void set_index(T index) { | ||
DRAKE_ASSERT(index >= 0); | ||
value_ = index; | ||
} | ||
|
||
/* Sets `this` to the inactive state. */ | ||
void set_inactive() { value_ = kInactive; } | ||
|
||
/* Sets `this` to the flag state. | ||
@pre !is_index() (i.e., must currently be inactive) */ | ||
void set_flag() { | ||
DRAKE_ASSERT(!is_index()); | ||
value_ = kFlag; | ||
} | ||
|
||
/* Returns true iff `this` is an active index (i.e., a non-negative integer). | ||
*/ | ||
constexpr bool is_index() const { return value_ >= 0; } | ||
|
||
/* Returns true iff `this` is in the inactive state. */ | ||
constexpr bool is_inactive() const { return value_ == kInactive; } | ||
|
||
/* Returns true iff `this` is in the flag state. */ | ||
constexpr bool is_flag() const { return value_ == kFlag; } | ||
|
||
/* Returns the index value. | ||
@pre is_index() == true; */ | ||
constexpr T index() const { | ||
DRAKE_ASSERT(is_index()); | ||
return value_; | ||
} | ||
|
||
private: | ||
/* Note that the enum values are not arbitrary; kInactive must be -1 as in the | ||
class documentation. */ | ||
enum : T { kInactive = -1, kFlag = -2 }; | ||
/* We encode all information in `value_` to satisfy the size requirement laid | ||
out in the class documentation. */ | ||
T value_{kInactive}; | ||
}; | ||
|
||
/* GridData stores data at a single grid node of SparseGrid. | ||
It contains the mass and velocity of the node along with a scratch space for | ||
temporary storage and an index to the node. | ||
It's important to be conscious of the size of GridData since the MPM algorithm | ||
is usually memory-bound. We carefully pack GridData to be a power of 2 to work | ||
with SPGrid, which automatically packs the data to a power of 2. | ||
@tparam T double or float. */ | ||
template <typename T> | ||
struct GridData { | ||
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>, | ||
"T must be float or double."); | ||
|
||
/* Resets `this` GridData to its default state where all floating point values | ||
are set to NAN and the index is inactive. */ | ||
void reset() { *this = {}; } | ||
|
||
/* Returns true iff `this` GridData is bit-wise equal to `other`. */ | ||
bool operator==(const GridData<T>& other) const { | ||
return std::memcmp(this, &other, sizeof(GridData<T>)) == 0; | ||
} | ||
|
||
Vector3<T> v{Vector3<T>::Constant(nan_with_all_bits_set())}; | ||
T m{nan_with_all_bits_set()}; | ||
Vector3<T> scratch{Vector3<T>::Constant(nan_with_all_bits_set())}; | ||
std::conditional_t<std::is_same_v<T, float>, IndexOrFlag<int32_t>, | ||
IndexOrFlag<int64_t>> | ||
index_or_flag{}; | ||
|
||
private: | ||
/* Returns a floating point NaN value with all bits set to one. This choice | ||
makes the reset() function more efficient. In particlar, it allows the | ||
generated machine code to memset all bits to 1 instead of handling each field | ||
individually. */ | ||
static T nan_with_all_bits_set() { | ||
using IntType = | ||
std::conditional_t<std::is_same_v<T, float>, int32_t, int64_t>; | ||
constexpr IntType kAllBitsOn = -1; | ||
return drake::internal::bit_cast<T>(kAllBitsOn); | ||
} | ||
}; | ||
|
||
/* With T = float, GridData is expected to be 32 bytes. With T = double, | ||
GridData is expected to be 64 bytes. We enforce these sizes at compile time | ||
with static_assert, so that if future changes to this code, compiler alignment, | ||
or Eigen alignment rules cause a size shift, it will be caught early. */ | ||
static_assert(sizeof(GridData<float>) == 32, | ||
"Unexpected size for GridData<float>."); | ||
static_assert(sizeof(GridData<double>) == 64, | ||
"Unexpected size for GridData<double>."); | ||
|
||
} // namespace internal | ||
} // namespace mpm | ||
} // namespace multibody | ||
} // namespace drake |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
#include "drake/multibody/mpm/grid_data.h" | ||
|
||
#include <gtest/gtest.h> | ||
|
||
namespace drake { | ||
namespace multibody { | ||
namespace mpm { | ||
namespace internal { | ||
namespace { | ||
|
||
using IndexTypes = ::testing::Types<int32_t, int64_t>; | ||
|
||
template <typename T> | ||
class IndexOrFlagTest : public ::testing::Test {}; | ||
|
||
TYPED_TEST_SUITE(IndexOrFlagTest, IndexTypes); | ||
|
||
TYPED_TEST(IndexOrFlagTest, Basic) { | ||
using T = TypeParam; | ||
IndexOrFlag<T> dut; | ||
EXPECT_FALSE(dut.is_index()); | ||
EXPECT_FALSE(dut.is_flag()); | ||
EXPECT_TRUE(dut.is_inactive()); | ||
|
||
dut.set_index(123); | ||
EXPECT_EQ(dut.index(), 123); | ||
dut.set_inactive(); | ||
EXPECT_TRUE(dut.is_inactive()); | ||
dut.set_flag(); | ||
EXPECT_TRUE(dut.is_flag()); | ||
/* Setting flag twice is allowed. */ | ||
dut.set_flag(); | ||
EXPECT_TRUE(dut.is_flag()); | ||
} | ||
|
||
TYPED_TEST(IndexOrFlagTest, StateTransition) { | ||
using T = TypeParam; | ||
IndexOrFlag<T> dut(123); | ||
EXPECT_TRUE(dut.is_index()); | ||
EXPECT_FALSE(dut.is_flag()); | ||
EXPECT_FALSE(dut.is_inactive()); | ||
|
||
/* Active -> Inactive */ | ||
dut.set_inactive(); | ||
EXPECT_FALSE(dut.is_index()); | ||
EXPECT_FALSE(dut.is_flag()); | ||
EXPECT_TRUE(dut.is_inactive()); | ||
|
||
/* Inactive -> Flag */ | ||
dut.set_flag(); | ||
EXPECT_FALSE(dut.is_index()); | ||
EXPECT_TRUE(dut.is_flag()); | ||
EXPECT_FALSE(dut.is_inactive()); | ||
|
||
/* Flag -> Inactive */ | ||
dut.set_inactive(); | ||
EXPECT_FALSE(dut.is_index()); | ||
EXPECT_FALSE(dut.is_flag()); | ||
EXPECT_TRUE(dut.is_inactive()); | ||
|
||
/* Inactive -> Active */ | ||
dut.set_index(123); | ||
EXPECT_TRUE(dut.is_index()); | ||
EXPECT_FALSE(dut.is_flag()); | ||
EXPECT_FALSE(dut.is_inactive()); | ||
EXPECT_EQ(dut.index(), 123); | ||
|
||
/* Additional scenario: Flag -> Active */ | ||
IndexOrFlag<T> another_dut; | ||
another_dut.set_flag(); | ||
another_dut.set_index(123); | ||
EXPECT_TRUE(another_dut.is_index()); | ||
EXPECT_EQ(another_dut.index(), 123); | ||
EXPECT_FALSE(another_dut.is_flag()); | ||
EXPECT_FALSE(another_dut.is_inactive()); | ||
} | ||
|
||
using FloatingPointTypes = ::testing::Types<float, double>; | ||
|
||
template <typename T> | ||
class GridDataTest : public ::testing::Test {}; | ||
|
||
TYPED_TEST_SUITE(GridDataTest, FloatingPointTypes); | ||
|
||
TYPED_TEST(GridDataTest, Reset) { | ||
using T = TypeParam; | ||
GridData<T> data; | ||
data.index_or_flag.set_index(123); | ||
data.scratch = Vector3<T>::Ones(); | ||
data.v = Vector3<T>::Ones(); | ||
data.m = 1; | ||
|
||
data.reset(); | ||
EXPECT_TRUE(data.index_or_flag.is_inactive()); | ||
EXPECT_NE(data.scratch, data.scratch); | ||
EXPECT_NE(data.v, data.v); | ||
EXPECT_TRUE(std::isnan(data.m)); | ||
} | ||
|
||
TYPED_TEST(GridDataTest, Equality) { | ||
using T = TypeParam; | ||
GridData<T> data1; | ||
data1.index_or_flag.set_index(123); | ||
data1.scratch = Vector3<T>::Ones(); | ||
data1.v = Vector3<T>::Ones(); | ||
data1.m = 1; | ||
|
||
GridData<T> data2; | ||
data2.index_or_flag.set_index(123); | ||
data2.scratch = Vector3<T>::Ones(); | ||
data2.v = Vector3<T>::Ones(); | ||
data2.m = 1; | ||
|
||
EXPECT_EQ(data1, data2); | ||
data2.index_or_flag.set_inactive(); | ||
EXPECT_NE(data1, data2); | ||
data1.index_or_flag.set_inactive(); | ||
EXPECT_EQ(data1, data2); | ||
data1.reset(); | ||
EXPECT_NE(data1, data2); | ||
data2.reset(); | ||
EXPECT_EQ(data1, data2); | ||
} | ||
|
||
} // namespace | ||
} // namespace internal | ||
} // namespace mpm | ||
} // namespace multibody | ||
} // namespace drake |