Skip to content

Commit

Permalink
Use correct devices for attr tensors in SelectByIndex.
Browse files Browse the repository at this point in the history
  • Loading branch information
ssheorey committed Dec 29, 2023
1 parent a590c77 commit f382849
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 13 deletions.
6 changes: 5 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,11 @@ option(BUILD_BENCHMARKS "Build the micro benchmarks" OFF
option(BUILD_PYTHON_MODULE "Build the python module" ON )
option(BUILD_CUDA_MODULE "Build the CUDA module" OFF)
option(BUILD_COMMON_CUDA_ARCHS "Build for common CUDA GPUs (for release)" OFF)
option(ENABLE_CACHED_CUDA_MANAGER "Enable cached CUDA memory manager" ON )
if (WIN32) # Causes CUDA runtime error on Windows (See issue #6555)
option(ENABLE_CACHED_CUDA_MANAGER "Enable cached CUDA memory manager" OFF)
else()
option(ENABLE_CACHED_CUDA_MANAGER "Enable cached CUDA memory manager" ON )
endif()
if(NOT LINUX_AARCH64 AND NOT APPLE_AARCH64)
option(BUILD_ISPC_MODULE "Build the ISPC module" ON )
else()
Expand Down
21 changes: 9 additions & 12 deletions cpp/open3d/t/geometry/TriangleMesh.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1029,8 +1029,8 @@ int TriangleMesh::PCAPartition(int max_faces) {
}

/// A helper to compute new vertex indices out of vertex mask.
/// \param tris_cpu tensor with triangle indices to update.
/// \param vertex_mask tensor with the mask for vertices.
/// \param tris_cpu CPU tensor with triangle indices to update.
/// \param vertex_mask CPU tensor with the mask for vertices.
template <typename T>
static void UpdateTriangleIndicesByVertexMask(core::Tensor &tris_cpu,
const core::Tensor &vertex_mask) {
Expand Down Expand Up @@ -1148,11 +1148,10 @@ static bool IsNegative(T val) {
}

TriangleMesh TriangleMesh::SelectByIndex(const core::Tensor &indices) const {
TriangleMesh result;
core::AssertTensorShape(indices, {indices.GetLength()});
if (!HasVertexPositions()) {
utility::LogWarning("[SelectByIndex] TriangleMesh has no vertices.");
return result;
return {};
}
GetVertexAttr().AssertSizeSynchronized();

Expand Down Expand Up @@ -1194,7 +1193,7 @@ TriangleMesh TriangleMesh::SelectByIndex(const core::Tensor &indices) const {
scalar_tris_t *vertex_mask_ptr =
vertex_mask.GetDataPtr<scalar_tris_t>();
const scalar_indices_t *indices_ptr =
indices.GetDataPtr<scalar_indices_t>();
indices_cpu.GetDataPtr<scalar_indices_t>();
for (int64_t i = 0; i < indices.GetLength(); ++i) {
if (IsNegative(indices_ptr[i]) ||
indices_ptr[i] >=
Expand Down Expand Up @@ -1233,16 +1232,14 @@ TriangleMesh TriangleMesh::SelectByIndex(const core::Tensor &indices) const {
});
});

// send the vertex mask to original device and apply to vertices
// send the vertex mask and triangle mask to original device and apply to
// vertices
vertex_mask = vertex_mask.To(GetDevice(), core::Bool);
tri_mask = tri_mask.To(GetDevice());
core::Tensor new_vertices = GetVertexPositions().IndexGet({vertex_mask});
result.SetVertexPositions(new_vertices);

if (HasTriangleIndices()) {
// select triangles and send the selected ones to the original device
result.SetTriangleIndices(tris_cpu.To(GetDevice()));
}
core::Tensor new_tris = tris_cpu.To(GetDevice());

TriangleMesh result(new_vertices, new_tris);
CopyAttributesByMasks(result, *this, vertex_mask, tri_mask);

return result;
Expand Down

0 comments on commit f382849

Please sign in to comment.