Skip to content

Commit

Permalink
fixup! Implement open3d::t::geometry::TriangleMesh::SelectByIndex
Browse files Browse the repository at this point in the history
  • Loading branch information
nsaiapova committed Oct 13, 2023
1 parent 1b1c507 commit 790fc34
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 30 deletions.
42 changes: 29 additions & 13 deletions cpp/open3d/t/geometry/TriangleMesh.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1061,9 +1061,9 @@ TriangleMesh TriangleMesh::SelectFacesByMask(const core::Tensor &mask) const {
return result;
}

// A helper to compute the vertex and triangle masks based on indices.
// Additionally updates tris_cpu to new indices.
template <typename intT>
/// \brief A helper to compute the vertex and triangle masks based on indices.
/// Additionally updates tris_cpu to new indices.
template <typename tri_intT, typename indices_intT>
static void SBIUpdateMasksAndTrisCPUHelper(const core::Tensor &indices,
core::Tensor &vertex_mask,
core::Tensor &tris_mask,
Expand All @@ -1072,25 +1072,29 @@ static void SBIUpdateMasksAndTrisCPUHelper(const core::Tensor &indices,
const int64_t num_verts = vertex_mask.GetLength();

// compute the vertices mask
intT *vertex_mask_ptr = vertex_mask.GetDataPtr<intT>();
const intT *indices_ptr = indices.GetDataPtr<intT>();
tri_intT *vertex_mask_ptr = vertex_mask.GetDataPtr<tri_intT>();
const indices_intT *indices_ptr = indices.GetDataPtr<indices_intT>();
for (int64_t i = 0; i < indices.GetLength(); ++i) {
if (indices_ptr[i] < 0) {
utility::LogError(
"[SelectByIndex] indices contains a negative index {}. ",
indices_ptr[i]);
}
if (indices_ptr[i] >= num_verts) {
utility::LogError(
"[SelectByIndex] indices contains index {} out of range. ",
indices_ptr[i]);
continue;
}
vertex_mask_ptr[indices_ptr[i]] = 1;
}

// compute new vertix indices
std::vector<intT> prefix_sum(num_verts + 1, 0);
std::vector<tri_intT> prefix_sum(num_verts + 1, 0);
utility::InclusivePrefixSum(vertex_mask_ptr, vertex_mask_ptr + num_verts,
&prefix_sum[1]);

// update the triangles with new indices and build the triangle mask
intT *tris_cpu_ptr = tris_cpu.GetDataPtr<intT>();
tri_intT *tris_cpu_ptr = tris_cpu.GetDataPtr<tri_intT>();
bool *tris_mask_ptr = tris_mask.GetDataPtr<bool>();
for (int64_t i = 0; i < num_tris; ++i) {
if (vertex_mask_ptr[tris_cpu_ptr[3 * i]] == 1 &&
Expand All @@ -1105,10 +1109,17 @@ static void SBIUpdateMasksAndTrisCPUHelper(const core::Tensor &indices,
}

TriangleMesh TriangleMesh::SelectByIndex(const core::Tensor &indices) const {
core::AssertTensorShape(indices, {indices.GetLength()});
if (!HasTriangleIndices()) {
utility::LogError("[SelectByIndex] mesh has no triangle indices.");
}
if (!HasVertexPositions()) {
utility::LogError("[SelectByIndex] mesh has no vertex positions.");
}
GetTriangleAttr().AssertSizeSynchronized();
GetVertexAttr().AssertSizeSynchronized();
if (GetTriangleIndices().GetDtype() == core::Int32) {
core::AssertTensorDtype(indices, core::Int32);
core::AssertTensorDtype(indices, GetTriangleIndices().GetDtype());
} else {
// we allow both Int32 and Int64 if the mesh indicies are Int64
core::AssertTensorDtypes(indices, {core::Int32, core::Int64});
Expand All @@ -1130,11 +1141,16 @@ TriangleMesh TriangleMesh::SelectByIndex(const core::Tensor &indices) const {

// compute vertex and triangular masks and triangles based on indices
if (tris_cpu.GetDtype() == core::Int32) {
SBIUpdateMasksAndTrisCPUHelper<int32_t>(indices_cpu, vertex_mask,
tris_mask, tris_cpu);
SBIUpdateMasksAndTrisCPUHelper<int32_t, int32_t>(
indices_cpu, vertex_mask, tris_mask, tris_cpu);
} else {
SBIUpdateMasksAndTrisCPUHelper<int64_t>(indices_cpu, vertex_mask,
tris_mask, tris_cpu);
if (indices_cpu.GetDtype() == core::Int32) {
SBIUpdateMasksAndTrisCPUHelper<int64_t, int32_t>(
indices_cpu, vertex_mask, tris_mask, tris_cpu);
} else {
SBIUpdateMasksAndTrisCPUHelper<int64_t, int64_t>(
indices_cpu, vertex_mask, tris_mask, tris_cpu);
}
}

// select triangles and send the selected ones to the original device
Expand Down
7 changes: 4 additions & 3 deletions cpp/open3d/t/geometry/TriangleMesh.h
Original file line number Diff line number Diff line change
Expand Up @@ -931,9 +931,10 @@ class TriangleMesh : public Geometry, public DrawableGeometry {
TriangleMesh SelectFacesByMask(const core::Tensor &mask) const;

/// Returns a new mesh with the vertices selected by a vector of indices.
/// Throws an exception if an item from the indices list exceeds the max
/// vertex number of the mesh.
/// \param indices An integer list of indices. Duplicates are
/// Throws an exception if the mesh is empty or if an item from the indices
/// list exceeds the max vertex number of the mesh or a negative value was
/// supplied.
/// \param indices An integer list of non-negative indices. Duplicates are
/// allowed, but ignored. If vertex indices of the mesh are of type Int64,
/// both Int32 and Int64 are allowed as indices type, otherwise only Int32
/// is accepted.
Expand Down
7 changes: 3 additions & 4 deletions cpp/pybind/t/geometry/trianglemesh.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -928,8 +928,8 @@ the partition id for each face.
triangle_mesh.def(
"select_by_index", &TriangleMesh::SelectByIndex, "indices"_a,
R"(Returns a new mesh with the vertices selected according to the indices list.
Throws an exception if an item from the indices list exceeds the max vertex
number of the mesh,
Throws an exception if the mesh is empty or if an item from the indices list exceeds
the max vertex number of the mesh or a negative value was supplied.
Args:
indices (open3d.core.Tensor): An integer list of indices. Duplicates are
Expand All @@ -942,8 +942,7 @@ number of the mesh,
Example:
This code selets the top face of a box, which has indices [2, 3, 6, 7].
parts::
This code selects the top face of a box, which has indices [2, 3, 6, 7]::
import open3d as o3d
import numpy as np
Expand Down
34 changes: 29 additions & 5 deletions cpp/tests/t/geometry/TriangleMesh.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -941,7 +941,12 @@ TEST_P(TriangleMeshPermuteDevices, CreateMobius) {
triangle_indices_custom));
}

TEST_P(TriangleMeshPermuteDevices, SelectByIndex_Box) {
TEST_P(TriangleMeshPermuteDevices, SelectByIndex) {
// check that an exception is thrown if the mesh is empty
t::geometry::TriangleMesh mesh_empty;
core::Tensor indices_empty = core::Tensor::Init<int64_t>({});
EXPECT_THROW(mesh_empty.SelectByIndex(indices_empty), std::runtime_error);

// create box with normals, colors and labels defined.
t::geometry::TriangleMesh box = t::geometry::TriangleMesh::CreateBox();
core::Tensor vertex_colors = core::Tensor::Init<float>({{0.0, 0.0, 0.0},
Expand Down Expand Up @@ -982,10 +987,10 @@ TEST_P(TriangleMeshPermuteDevices, SelectByIndex_Box) {
box.ComputeTriangleNormals();
box.SetTriangleAttr("labels", triangle_labels);

core::Tensor indices = core::Tensor::Init<int64_t>({2, 3, 6, 7});
t::geometry::TriangleMesh selected = box.SelectByIndex(indices);
// empty index list
EXPECT_TRUE(box.SelectByIndex(indices_empty).IsEmpty());

// Set the expected values.
// set the expected valuee
core::Tensor expected_verts = core::Tensor::Init<float>({{0.0, 0.0, 1.0},
{1.0, 0.0, 1.0},
{0.0, 1.0, 1.0},
Expand All @@ -1000,7 +1005,6 @@ TEST_P(TriangleMeshPermuteDevices, SelectByIndex_Box) {
{30.0, 30.0, 30.0},
{60.0, 60.0, 60.0},
{70.0, 70.0, 70.0}});

core::Tensor expected_tris =
core::Tensor::Init<int64_t>({{0, 1, 3}, {0, 3, 2}});
core::Tensor tris_mask =
Expand All @@ -1010,6 +1014,9 @@ TEST_P(TriangleMeshPermuteDevices, SelectByIndex_Box) {
core::Tensor expected_tri_labels = core::Tensor::Init<float>(
{{800.0, 800.0, 800.0}, {900.0, 900.0, 900.0}});

core::Tensor indices = core::Tensor::Init<int64_t>({2, 3, 6, 7});
t::geometry::TriangleMesh selected = box.SelectByIndex(indices);

EXPECT_TRUE(selected.GetVertexPositions().AllClose(expected_verts));
EXPECT_TRUE(selected.GetVertexColors().AllClose(expected_vert_colors));
EXPECT_TRUE(
Expand All @@ -1019,6 +1026,23 @@ TEST_P(TriangleMeshPermuteDevices, SelectByIndex_Box) {
EXPECT_TRUE(
selected.GetTriangleAttr("labels").AllClose(expected_tri_labels));

core::Tensor indices_duplicate =
core::Tensor::Init<int32_t>({2, 2, 3, 3, 6, 7, 7});
t::geometry::TriangleMesh selected_duplicate =
box.SelectByIndex(indices_duplicate);
EXPECT_TRUE(
selected_duplicate.GetVertexPositions().AllClose(expected_verts));
EXPECT_TRUE(selected_duplicate.GetVertexColors().AllClose(
expected_vert_colors));
EXPECT_TRUE(selected_duplicate.GetVertexAttr("labels").AllClose(
expected_vert_labels));
EXPECT_TRUE(
selected_duplicate.GetTriangleIndices().AllClose(expected_tris));
EXPECT_TRUE(selected_duplicate.GetTriangleNormals().AllClose(
expected_tri_normals));
EXPECT_TRUE(selected_duplicate.GetTriangleAttr("labels").AllClose(
expected_tri_labels));

// Check that initial mesh is unchanged.
t::geometry::TriangleMesh box_untouched =
t::geometry::TriangleMesh::CreateBox();
Expand Down
79 changes: 74 additions & 5 deletions python/test/t/geometry/test_trianglemesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ def test_pickle(device):


@pytest.mark.parametrize("device", list_devices())
def test_select_by_index(device):
def test_select_by_index_32(device):
sphere_custom = o3d.t.geometry.TriangleMesh.create_sphere(
1, 3, o3c.float64, o3c.int32, device)

Expand All @@ -432,10 +432,20 @@ def test_select_by_index(device):
expected_tris = o3c.Tensor([[0, 1, 2], [0, 3, 4], [0, 4, 5], [0, 5, 1]],
o3c.int32, device)

indices = o3c.Tensor([0, 2, 3, 5, 6, 7], o3c.int64, device)
# check indices shape mismatch
indices_2d = o3c.Tensor([[0, 2], [3, 5], [6, 7]], o3c.int32, device)
with pytest.raises(RuntimeError):
selected = sphere_custom.select_by_index(indices_2d)

# check indices int size mismatch
indices_64 = o3c.Tensor([0, 2, 3, 5, 6, 7], o3c.int64, device)
with pytest.raises(RuntimeError):
selected = sphere_custom.select_by_index(indices_64)

# check indices type mismatch
with pytest.raises(RuntimeError) as e:
selected = sphere_custom.select_by_index(indices)
indices_float = o3c.Tensor([2.0, 4.0], o3c.float32, device)
with pytest.raises(RuntimeError):
selected = sphere_custom.select_by_index(indices_float)

# check the expected mesh
indices = o3c.Tensor([0, 2, 3, 5, 6, 7], o3c.int32, device)
Expand All @@ -453,5 +463,64 @@ def test_select_by_index(device):

# check that the exception is thrown if one of the indices exceeds
# the max vertex index of the mesh
with pytest.raises(RuntimeError) as e:
with pytest.raises(RuntimeError):
selected = sphere_custom.select_by_index([2, 3, 6, 99])

# check that the exception is thrown if one of the indices have a negative
# value
with pytest.raises(RuntimeError):
selected = sphere_custom.select_by_index([2, 3, 6, -7])


@pytest.mark.parametrize("device", list_devices())
def test_select_by_index_64(device):
sphere_custom = o3d.t.geometry.TriangleMesh.create_sphere(
1, 3, o3c.float64, o3c.int64, device)

# check indices shape mismatch
with pytest.raises(RuntimeError):
indices_2d = o3c.Tensor([[0, 2], [3, 5], [6, 7]], o3c.int64, device)
selected = sphere_custom.select_by_index(indices_2d)

# check indices type mismatch
with pytest.raises(RuntimeError):
indices_float = o3c.Tensor([2.0, 4.0], o3c.float64, device)
selected = sphere_custom.select_by_index(indices_float)

expected_verts = o3c.Tensor(
[[0.0, 0.0, 1.0], [0.866025, 0, 0.5], [0.433013, 0.75, 0.5],
[-0.866025, 0.0, 0.5], [-0.433013, -0.75, 0.5], [0.433013, -0.75, 0.5]
], o3c.float64, device)

expected_tris = o3c.Tensor([[0, 1, 2], [0, 3, 4], [0, 4, 5], [0, 5, 1]],
o3c.int64, device)

# check the expected mesh with int64 input
indices_64 = o3c.Tensor([0, 2, 3, 5, 6, 7], o3c.int64, device)
selected = sphere_custom.select_by_index(indices_64)
assert selected.vertex.positions.allclose(expected_verts)
assert selected.triangle.indices.allclose(expected_tris)

# check the expected mesh with int32 input and unsorted indices
indices_32 = o3c.Tensor([7, 6, 3, 5, 0, 2], o3c.int32, device)
selected = sphere_custom.select_by_index(indices_32)
assert selected.vertex.positions.allclose(expected_verts)
assert selected.triangle.indices.allclose(expected_tris)

# check that the original mesh is unmodified
untouched_sphere = o3d.t.geometry.TriangleMesh.create_sphere(
1, 3, o3c.float64, o3c.int64, device)
assert sphere_custom.vertex.positions.allclose(
untouched_sphere.vertex.positions)
assert sphere_custom.triangle.indices.allclose(
untouched_sphere.triangle.indices)

# check that the exception is thrown if one of the indices exceeds
# the max vertex index of the mesh
with pytest.raises(RuntimeError):
selected = sphere_custom.select_by_index([2, 3, 6, 99])

# check that the exception is thrown if one of the indices have a negative
# value
with pytest.raises(RuntimeError):
selected = sphere_custom.select_by_index([2, 3, 6, -7])

0 comments on commit 790fc34

Please sign in to comment.