Skip to content

Commit

Permalink
Implement t::geometry::TriangleMesh::RemoveUnreferencedVertices (#6640)
Browse files Browse the repository at this point in the history
The algorithm mimics the one in
geometry::TriangleMesh::RemoveUnreferencedVertices.
We first build a mask of vertices and then update all vertex attributes
by that mask. Triangles are left untouched.
  • Loading branch information
nsaiapova authored Mar 18, 2024
1 parent fa91f2e commit e8661f7
Show file tree
Hide file tree
Showing 5 changed files with 241 additions and 0 deletions.
54 changes: 54 additions & 0 deletions cpp/open3d/t/geometry/TriangleMesh.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1254,6 +1254,60 @@ TriangleMesh TriangleMesh::SelectByIndex(const core::Tensor &indices) const {
return result;
}

TriangleMesh TriangleMesh::RemoveUnreferencedVertices() {
if (!HasVertexPositions() || GetVertexPositions().GetLength() == 0) {
utility::LogWarning(
"[RemoveUnreferencedVertices] TriangleMesh has no vertices.");
return *this;
}
GetVertexAttr().AssertSizeSynchronized();

core::Dtype tri_dtype = HasTriangleIndices()
? GetTriangleIndices().GetDtype()
: core::Int64;

int64_t num_verts_old = GetVertexPositions().GetLength();
// int mask for vertices as we need to remap indices.
core::Tensor vertex_mask = core::Tensor::Zeros({num_verts_old}, tri_dtype);

if (!HasTriangleIndices() || GetTriangleIndices().GetLength() == 0) {
utility::LogWarning(
"[RemoveUnreferencedVertices] TriangleMesh has no triangles. "
"Removing all vertices.");
// in this case we need to empty vertices and their attributes
} else {
GetTriangleAttr().AssertSizeSynchronized();
core::Tensor tris_cpu =
GetTriangleIndices().To(core::Device()).Contiguous();
DISPATCH_INT_DTYPE_PREFIX_TO_TEMPLATE(tri_dtype, tris, [&]() {
scalar_tris_t *tris_ptr = tris_cpu.GetDataPtr<scalar_tris_t>();
scalar_tris_t *vertex_mask_ptr =
vertex_mask.GetDataPtr<scalar_tris_t>();
for (int i = 0; i < tris_cpu.GetLength(); i++) {
vertex_mask_ptr[tris_ptr[3 * i]] = 1;
vertex_mask_ptr[tris_ptr[3 * i + 1]] = 1;
vertex_mask_ptr[tris_ptr[3 * i + 2]] = 1;
}

UpdateTriangleIndicesByVertexMask<scalar_tris_t>(tris_cpu,
vertex_mask);
});
}

// send the vertex mask to original device and apply to
// vertices
vertex_mask = vertex_mask.To(GetDevice(), core::Bool);
for (auto item : GetVertexAttr()) {
SetVertexAttr(item.first, item.second.IndexGet({vertex_mask}));
}

utility::LogDebug(
"[RemoveUnreferencedVertices] {:d} vertices have been removed.",
(int)(num_verts_old - GetVertexPositions().GetLength()));

return *this;
}

} // namespace geometry
} // namespace t
} // namespace open3d
4 changes: 4 additions & 0 deletions cpp/open3d/t/geometry/TriangleMesh.h
Original file line number Diff line number Diff line change
Expand Up @@ -945,6 +945,10 @@ class TriangleMesh : public Geometry, public DrawableGeometry {
/// an empty mesh.
TriangleMesh SelectByIndex(const core::Tensor &indices) const;

/// Removes unreferenced vertices from the mesh.
/// \return The reference to itself.
TriangleMesh RemoveUnreferencedVertices();

protected:
core::Device device_ = core::Device("CPU:0");
TensorMap vertex_attr_;
Expand Down
4 changes: 4 additions & 0 deletions cpp/pybind/t/geometry/trianglemesh.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -962,6 +962,10 @@ or has a negative value, it is ignored.
box = o3d.t.geometry.TriangleMesh.create_box()
top_face = box.select_by_index([2, 3, 6, 7])
)");

triangle_mesh.def("remove_unreferenced_vertices",
&TriangleMesh::RemoveUnreferencedVertices,
"Removes unreferenced vertices from the mesh in-place.");
}

} // namespace geometry
Expand Down
130 changes: 130 additions & 0 deletions cpp/tests/t/geometry/TriangleMesh.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1212,5 +1212,135 @@ TEST_P(TriangleMeshPermuteDevices, SelectByIndex) {
box_untouched.GetTriangleIndices()));
}

TEST_P(TriangleMeshPermuteDevices, RemoveUnreferencedVertices) {
core::Device device = GetParam();
t::geometry::TriangleMesh mesh_empty{device};

// check completely empty mesh
EXPECT_TRUE(mesh_empty.RemoveUnreferencedVertices().IsEmpty());

// check mesh w/o triangles
core::Tensor vertices_no_tris_orig =
core::Tensor::Ones({2, 3}, core::Float32, device);
mesh_empty.SetVertexPositions(vertices_no_tris_orig);
EXPECT_TRUE(mesh_empty.RemoveUnreferencedVertices().IsEmpty());

// Torus
core::Tensor verts = core::Tensor::Init<double>(
{
{0, 0, 0}, /* 0 */
{3.0, 0.0, 0.0},
{1.5, 0.0, 0.866025},
{1, 2, 3}, /* 3 */
{1.5, 0.0, -0.866025},
{1.5, 2.59808, 0.0},
{0.75, 1.29904, 0.866025},
{0.75, 1.29904, -0.866025},
{-1.5, 2.59808, 0},
{-0.75, 1.29904, 0.866025},
{-0.75, 1.29904, -0.866025},
{-3.0, 0.0, 0.0},
{-1.5, 0.0, 0.866025},
{-1.5, 0.0, -0.866025},
{-1.5, -2.59808, 0.0},
{-0.75, -1.29904, 0.866025},
{-0.75, -1.29904, -0.866025},
{4, 5, 6}, /* 17 */
{1.5, -2.59808, 0.0},
{0.75, -1.29904, 0.866025},
{0.75, -1.29904, -0.866025},
{7, 8, 9} /* 21 */
},
device);

core::Tensor tris = core::Tensor::Init<int32_t>(
{{5, 6, 1}, {1, 6, 2}, {6, 7, 2}, {2, 7, 4},
{7, 5, 4}, {4, 5, 1}, {8, 9, 5}, {5, 9, 6},
{9, 10, 6}, {6, 10, 7}, {10, 8, 7}, {7, 8, 5},
{11, 12, 8}, {8, 12, 9}, {12, 13, 9}, {9, 13, 10},
{13, 11, 10}, {10, 11, 8}, {14, 15, 11}, {11, 15, 12},
{15, 16, 12}, {12, 16, 13}, {16, 14, 13}, {13, 14, 11},
{18, 19, 14}, {14, 19, 15}, {19, 20, 15}, {15, 20, 16},
{20, 18, 16}, {16, 18, 14}, {1, 2, 18}, {18, 2, 19},
{2, 4, 19}, {19, 4, 20}, {4, 1, 20}, {20, 1, 18}},
device);
t::geometry::TriangleMesh torus{verts, tris};
core::Tensor vertex_colors = core::Tensor::Init<float>(
{{0.0, 0.0, 0.0}, {1.0, 1.0, 1.0}, {2.0, 2.0, 2.0},
{3.0, 3.0, 3.0}, {4.0, 4.0, 4.0}, {5.0, 5.0, 5.0},
{6.0, 6.0, 6.0}, {7.0, 7.0, 7.0}, {8.0, 8.0, 8.0},
{9.0, 9.0, 9.0}, {10.0, 10.0, 10.0}, {11.0, 11.0, 11.0},
{12.0, 12.0, 12.0}, {13.0, 13.0, 13.0}, {14.0, 14.0, 14.0},
{15.0, 15.0, 15.0}, {16.0, 16.0, 16.0}, {17.0, 17.0, 17.0},
{18.0, 18.0, 18.0}, {19.0, 19.0, 19.0}, {20.0, 20.0, 20.0},
{21.0, 21.0, 21.0}},
device);
core::Tensor vertex_labels =
core::Tensor::Init<float>(
{{0.0, 0.0, 0.0}, {1.0, 1.0, 1.0}, {2.0, 2.0, 2.0},
{3.0, 3.0, 3.0}, {4.0, 4.0, 4.0}, {5.0, 5.0, 5.0},
{6.0, 6.0, 6.0}, {7.0, 7.0, 7.0}, {8.0, 8.0, 8.0},
{9.0, 9.0, 9.0}, {10.0, 10.0, 10.0}, {11.0, 11.0, 11.0},
{12.0, 12.0, 12.0}, {13.0, 13.0, 13.0}, {14.0, 14.0, 14.0},
{15.0, 15.0, 15.0}, {16.0, 16.0, 16.0}, {17.0, 17.0, 17.0},
{18.0, 18.0, 18.0}, {19.0, 19.0, 19.0}, {20.0, 20.0, 20.0},
{21.0, 21.0, 21.0}},
device) *
10;

core::Tensor triangle_labels =
core::Tensor::Init<float>({{0.0, 0.0, 0.0}, {1.0, 1.0, 1.0},
{2.0, 2.0, 2.0}, {3.0, 3.0, 3.0},
{4.0, 4.0, 4.0}, {5.0, 5.0, 5.0},
{6.0, 6.0, 6.0}, {7.0, 7.0, 7.0},
{8.0, 8.0, 8.0}, {9.0, 9.0, 9.0},
{10.0, 10.0, 10.0}, {11.0, 11.0, 11.0},
{12.0, 12.0, 12.0}, {13.0, 13.0, 13.0},
{14.0, 14.0, 14.0}, {15.0, 15.0, 15.0},
{16.0, 16.0, 16.0}, {17.0, 17.0, 17.0},
{18.0, 18.0, 18.0}, {19.0, 19.0, 19.0},
{20.0, 20.0, 20.0}, {21.0, 21.0, 21.0},
{22.0, 22.0, 22.0}, {23.0, 23.0, 23.0},
{24.0, 24.0, 24.0}, {25.0, 25.0, 25.0},
{26.0, 26.0, 26.0}, {27.0, 27.0, 27.0},
{28.0, 28.0, 28.0}, {29.0, 29.0, 29.0},
{30.0, 30.0, 30.0}, {31.0, 31.0, 31.0},
{32.0, 32.0, 32.0}, {33.0, 33.0, 33.0},
{34.0, 34.0, 34.0}, {35.0, 35.0, 35.0}},
device) *
100;
torus.SetVertexColors(vertex_colors);
torus.SetVertexAttr("labels", vertex_labels);
torus.ComputeVertexNormals();
torus.ComputeTriangleNormals();

// set the expected value
core::Tensor verts_mask = core::Tensor::Init<bool>(
{0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0},
device);
core::Tensor expected_verts =
torus.GetVertexPositions().IndexGet({verts_mask});
core::Tensor expected_tris =
t::geometry::TriangleMesh::CreateTorus(2, 1, 6, 3, core::Float32,
core::Int32, device)
.GetTriangleIndices();
core::Tensor expected_vert_normals =
torus.GetVertexNormals().IndexGet({verts_mask});
core::Tensor expected_tri_normals = torus.GetTriangleNormals().Clone();
core::Tensor expected_vert_labels =
torus.GetVertexAttr("labels").IndexGet({verts_mask});
core::Tensor expected_vert_colors =
torus.GetVertexAttr("colors").IndexGet({verts_mask});

torus.RemoveUnreferencedVertices();

EXPECT_TRUE(torus.GetVertexPositions().AllClose(expected_verts));
EXPECT_TRUE(torus.GetVertexNormals().AllClose(expected_vert_normals));
EXPECT_TRUE(torus.GetVertexColors().AllClose(expected_vert_colors));
EXPECT_TRUE(torus.GetVertexAttr("labels").AllClose(expected_vert_labels));
EXPECT_TRUE(torus.GetTriangleIndices().AllClose(expected_tris));
EXPECT_TRUE(torus.GetTriangleNormals().AllClose(expected_tri_normals));
}

} // namespace tests
} // namespace open3d
49 changes: 49 additions & 0 deletions python/test/t/geometry/test_trianglemesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,3 +660,52 @@ def test_select_by_index_64(device):
untouched_sphere.vertex.positions)
assert sphere_custom.triangle.indices.allclose(
untouched_sphere.triangle.indices)


def check_no_unreferenced_vertices(device, int_t, float_t):
sphere = o3d.t.geometry.TriangleMesh.create_sphere(1, 3, float_t, int_t,
device)
expected_sphere = o3d.t.geometry.TriangleMesh.create_sphere(
1, 3, float_t, int_t, device)

sphere.remove_unreferenced_vertices()

# nothing should be removed
assert sphere.vertex.positions.allclose(expected_sphere.vertex.positions)
assert sphere.triangle.indices.allclose(expected_sphere.triangle.indices)


def check_remove_unreferenced_vertices(device, int_t, float_t):
expected_mobius = o3d.t.geometry.TriangleMesh.create_mobius(
10, 2, 1, 1, 1, 1, 1, float_t, int_t, device)

verts = o3c.Tensor(
[[0.5, 0.0, 0.0], [1.5, 0.0, 0.0], [0.424307, 0.308277, -0.154508],
[1.19373, 0.867294, 0.154508], [0.184017, 0.566346, -0.293893],
[0.434017, 1.33577, 0.293893], [-0.218199, 0.671548, -0.404508],
[-0.399835, 1.23057, 0.404508], [-0.684017, 0.496967, -0.475528],
[-0.934017, 0.678603, 0.475528], [-1.0, 0.0, -0.5], [-1.0, 0.0, 0.5],
[-0.934017, -0.678603, -0.475528], [-0.684017, -0.496967, 0.475528],
[-0.399835, -1.23057, -0.404508], [-0.218199, -0.671548, 0.404508],
[0.434017, -1.33577, -0.293893], [0.184017, -0.566346, 0.293893],
[0, 0, 0], [1.19373, -0.867294, -0.154508], [1, 1, 1],
[0.424307, -0.308277, 0.154508]], float_t, device)

tris = o3c.Tensor(
[[0, 3, 1], [0, 2, 3], [3, 2, 4], [3, 4, 5], [4, 7, 5], [4, 6, 7],
[7, 6, 8], [7, 8, 9], [8, 11, 9], [8, 10, 11], [11, 10, 12],
[11, 12, 13], [12, 15, 13], [12, 14, 15], [15, 14, 16], [15, 16, 17],
[16, 21, 17], [16, 19, 21], [19, 21, 1], [1, 21, 0]], int_t, device)

mobius = o3d.t.geometry.TriangleMesh(verts, tris)
mobius.remove_unreferenced_vertices()
assert mobius.vertex.positions.allclose(expected_mobius.vertex.positions)
assert mobius.triangle.indices.allclose(expected_mobius.triangle.indices)


@pytest.mark.parametrize("device", list_devices())
@pytest.mark.parametrize("int_t", (o3c.int32, o3c.int64))
@pytest.mark.parametrize("float_t", (o3c.float32, o3c.float64))
def test_remove_unreferenced_vertices(device, int_t, float_t):
check_no_unreferenced_vertices(device, int_t, float_t)
check_remove_unreferenced_vertices(device, int_t, float_t)

0 comments on commit e8661f7

Please sign in to comment.