Skip to content

Commit

Permalink
Add GridNodeIndex and GridData for MPM (#22295)
Browse files Browse the repository at this point in the history
  • Loading branch information
xuchenhan-tri authored Dec 13, 2024
1 parent fe6caf8 commit b95c19c
Show file tree
Hide file tree
Showing 3 changed files with 300 additions and 0 deletions.
19 changes: 19 additions & 0 deletions multibody/mpm/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ drake_cc_package_library(
visibility = ["//visibility:public"],
deps = [
":bspline_weights",
":grid_data",
],
)

Expand All @@ -32,6 +33,17 @@ drake_cc_library(
],
)

drake_cc_library(
name = "grid_data",
hdrs = [
"grid_data.h",
],
deps = [
"//common:bit_cast",
"//common:essential",
],
)

# TODO(xuchenhan-tri): when we enable SPGrid in our releases, we also need to
# install its license file in drake/tools/workspace/BUILD.bazel.

Expand All @@ -56,6 +68,13 @@ drake_cc_googletest(
],
)

drake_cc_googletest(
name = "grid_data_test",
deps = [
":grid_data",
],
)

drake_cc_googletest_linux_only(
name = "spgrid_test",
deps = [
Expand Down
152 changes: 152 additions & 0 deletions multibody/mpm/grid_data.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
#pragma once

#include <cstdint>

#include "drake/common/bit_cast.h"
#include "drake/common/drake_assert.h"
#include "drake/common/drake_copyable.h"
#include "drake/common/eigen_types.h"

namespace drake {
namespace multibody {
namespace mpm {
namespace internal {

/* This class is a lightweight wrapper around an integer type (`int32_t` or
`int64_t`) used to differentiate between active grid node indices, inactive
states, and a special flag. An IndexOrFlag can be in exactly one of the
following states:
1. Active index: A non-negative integer representing the index of a grid node.
2. Inactive state (the default state): neither an index nor a flag.
3. The `flag` state: A special state used to mark grid nodes for deferred
processing. The flag state is intended as a marker that must only be set
from the inactive state (not from an active index) to maintain consistency.
Transitions between states are as follows:
- Any state can become inactive.
- The inactive state can become flag state.
- The inactive state can become active (with a non-negative index).
- Active cannot directly become flag state (must go inactive first).
An IndexOrFlag object is guaranteed to have size equal to its template
parameter T. In addition, the inactive state is guaranteed to be represented
with the value -1.
@tparam T The integer type for the index. Must be `int32_t` or `int64_t`. */
template <typename T>
class IndexOrFlag {
public:
DRAKE_DEFAULT_COPY_AND_MOVE_AND_ASSIGN(IndexOrFlag);

static_assert(std::is_same_v<T, int32_t> || std::is_same_v<T, int64_t>,
"T must be int32_t or int64_t.");

/* Default constructor initializes the object to the inactive state. */
constexpr IndexOrFlag() = default;

/* Constructor for an active index.
@pre index >= 0 */
explicit constexpr IndexOrFlag(T index) { set_index(index); }

/* Sets the index to the given value, which must be non-negative, thereby
making `this` active.
@pre index >= 0 */
void set_index(T index) {
DRAKE_ASSERT(index >= 0);
value_ = index;
}

/* Sets `this` to the inactive state. */
void set_inactive() { value_ = kInactive; }

/* Sets `this` to the flag state.
@pre !is_index() (i.e., must currently be inactive) */
void set_flag() {
DRAKE_ASSERT(!is_index());
value_ = kFlag;
}

/* Returns true iff `this` is an active index (i.e., a non-negative integer).
*/
constexpr bool is_index() const { return value_ >= 0; }

/* Returns true iff `this` is in the inactive state. */
constexpr bool is_inactive() const { return value_ == kInactive; }

/* Returns true iff `this` is in the flag state. */
constexpr bool is_flag() const { return value_ == kFlag; }

/* Returns the index value.
@pre is_index() == true; */
constexpr T index() const {
DRAKE_ASSERT(is_index());
return value_;
}

private:
/* Note that the enum values are not arbitrary; kInactive must be -1 as in the
class documentation. */
enum : T { kInactive = -1, kFlag = -2 };
/* We encode all information in `value_` to satisfy the size requirement laid
out in the class documentation. */
T value_{kInactive};
};

/* GridData stores data at a single grid node of SparseGrid.
It contains the mass and velocity of the node along with a scratch space for
temporary storage and an index to the node.
It's important to be conscious of the size of GridData since the MPM algorithm
is usually memory-bound. We carefully pack GridData to be a power of 2 to work
with SPGrid, which automatically packs the data to a power of 2.
@tparam T double or float. */
template <typename T>
struct GridData {
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>,
"T must be float or double.");

/* Resets `this` GridData to its default state where all floating point values
are set to NAN and the index is inactive. */
void reset() { *this = {}; }

/* Returns true iff `this` GridData is bit-wise equal to `other`. */
bool operator==(const GridData<T>& other) const {
return std::memcmp(this, &other, sizeof(GridData<T>)) == 0;
}

Vector3<T> v{Vector3<T>::Constant(nan_with_all_bits_set())};
T m{nan_with_all_bits_set()};
Vector3<T> scratch{Vector3<T>::Constant(nan_with_all_bits_set())};
std::conditional_t<std::is_same_v<T, float>, IndexOrFlag<int32_t>,
IndexOrFlag<int64_t>>
index_or_flag{};

private:
/* Returns a floating point NaN value with all bits set to one. This choice
makes the reset() function more efficient. In particlar, it allows the
generated machine code to memset all bits to 1 instead of handling each field
individually. */
static T nan_with_all_bits_set() {
using IntType =
std::conditional_t<std::is_same_v<T, float>, int32_t, int64_t>;
constexpr IntType kAllBitsOn = -1;
return drake::internal::bit_cast<T>(kAllBitsOn);
}
};

/* With T = float, GridData is expected to be 32 bytes. With T = double,
GridData is expected to be 64 bytes. We enforce these sizes at compile time
with static_assert, so that if future changes to this code, compiler alignment,
or Eigen alignment rules cause a size shift, it will be caught early. */
static_assert(sizeof(GridData<float>) == 32,
"Unexpected size for GridData<float>.");
static_assert(sizeof(GridData<double>) == 64,
"Unexpected size for GridData<double>.");

} // namespace internal
} // namespace mpm
} // namespace multibody
} // namespace drake
129 changes: 129 additions & 0 deletions multibody/mpm/test/grid_data_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
#include "drake/multibody/mpm/grid_data.h"

#include <gtest/gtest.h>

namespace drake {
namespace multibody {
namespace mpm {
namespace internal {
namespace {

using IndexTypes = ::testing::Types<int32_t, int64_t>;

template <typename T>
class IndexOrFlagTest : public ::testing::Test {};

TYPED_TEST_SUITE(IndexOrFlagTest, IndexTypes);

TYPED_TEST(IndexOrFlagTest, Basic) {
using T = TypeParam;
IndexOrFlag<T> dut;
EXPECT_FALSE(dut.is_index());
EXPECT_FALSE(dut.is_flag());
EXPECT_TRUE(dut.is_inactive());

dut.set_index(123);
EXPECT_EQ(dut.index(), 123);
dut.set_inactive();
EXPECT_TRUE(dut.is_inactive());
dut.set_flag();
EXPECT_TRUE(dut.is_flag());
/* Setting flag twice is allowed. */
dut.set_flag();
EXPECT_TRUE(dut.is_flag());
}

TYPED_TEST(IndexOrFlagTest, StateTransition) {
using T = TypeParam;
IndexOrFlag<T> dut(123);
EXPECT_TRUE(dut.is_index());
EXPECT_FALSE(dut.is_flag());
EXPECT_FALSE(dut.is_inactive());

/* Active -> Inactive */
dut.set_inactive();
EXPECT_FALSE(dut.is_index());
EXPECT_FALSE(dut.is_flag());
EXPECT_TRUE(dut.is_inactive());

/* Inactive -> Flag */
dut.set_flag();
EXPECT_FALSE(dut.is_index());
EXPECT_TRUE(dut.is_flag());
EXPECT_FALSE(dut.is_inactive());

/* Flag -> Inactive */
dut.set_inactive();
EXPECT_FALSE(dut.is_index());
EXPECT_FALSE(dut.is_flag());
EXPECT_TRUE(dut.is_inactive());

/* Inactive -> Active */
dut.set_index(123);
EXPECT_TRUE(dut.is_index());
EXPECT_FALSE(dut.is_flag());
EXPECT_FALSE(dut.is_inactive());
EXPECT_EQ(dut.index(), 123);

/* Additional scenario: Flag -> Active */
IndexOrFlag<T> another_dut;
another_dut.set_flag();
another_dut.set_index(123);
EXPECT_TRUE(another_dut.is_index());
EXPECT_EQ(another_dut.index(), 123);
EXPECT_FALSE(another_dut.is_flag());
EXPECT_FALSE(another_dut.is_inactive());
}

using FloatingPointTypes = ::testing::Types<float, double>;

template <typename T>
class GridDataTest : public ::testing::Test {};

TYPED_TEST_SUITE(GridDataTest, FloatingPointTypes);

TYPED_TEST(GridDataTest, Reset) {
using T = TypeParam;
GridData<T> data;
data.index_or_flag.set_index(123);
data.scratch = Vector3<T>::Ones();
data.v = Vector3<T>::Ones();
data.m = 1;

data.reset();
EXPECT_TRUE(data.index_or_flag.is_inactive());
EXPECT_NE(data.scratch, data.scratch);
EXPECT_NE(data.v, data.v);
EXPECT_TRUE(std::isnan(data.m));
}

TYPED_TEST(GridDataTest, Equality) {
using T = TypeParam;
GridData<T> data1;
data1.index_or_flag.set_index(123);
data1.scratch = Vector3<T>::Ones();
data1.v = Vector3<T>::Ones();
data1.m = 1;

GridData<T> data2;
data2.index_or_flag.set_index(123);
data2.scratch = Vector3<T>::Ones();
data2.v = Vector3<T>::Ones();
data2.m = 1;

EXPECT_EQ(data1, data2);
data2.index_or_flag.set_inactive();
EXPECT_NE(data1, data2);
data1.index_or_flag.set_inactive();
EXPECT_EQ(data1, data2);
data1.reset();
EXPECT_NE(data1, data2);
data2.reset();
EXPECT_EQ(data1, data2);
}

} // namespace
} // namespace internal
} // namespace mpm
} // namespace multibody
} // namespace drake

0 comments on commit b95c19c

Please sign in to comment.