From 2603985282507b93cefea5e312baaffc30980a12 Mon Sep 17 00:00:00 2001 From: kadmin Date: Sat, 25 Jun 2022 21:34:23 +0000 Subject: [PATCH] Update test to include non-uniform case --- pytorch3d/structures/meshes.py | 18 +++++++++++------- tests/test_meshes.py | 23 +++++++++++++++++------ 2 files changed, 28 insertions(+), 13 deletions(-) diff --git a/pytorch3d/structures/meshes.py b/pytorch3d/structures/meshes.py index 97e0055a8..a58e273b3 100644 --- a/pytorch3d/structures/meshes.py +++ b/pytorch3d/structures/meshes.py @@ -1559,8 +1559,16 @@ def volume_centroid(self): """ v_idxs = self.faces_padded().split([1, 1, 1], dim=-1) verts = self.verts_padded() - - v0, v1, v2 = [torch.gather(verts, 1, idx.expand(-1, -1, 3)) for idx in v_idxs] + valid = (self.faces_padded() != -1).all(dim=-1, keepdim=True) + + v0, v1, v2 = [ + torch.gather( + verts, + 1, + idx.where(valid, torch.zeros_like(idx)).expand(-1, -1, 3), + ).where(valid, torch.zeros_like(idx, dtype=verts.dtype)) + for idx in v_idxs + ] tetra_center = (v0 + v1 + v2) / 4 signed_tetra_vol = (v0 * torch.cross(v1, v2, dim=-1)).sum( @@ -1568,11 +1576,7 @@ def volume_centroid(self): ) / 6 denom = signed_tetra_vol.sum(dim=-2) # clamp the denominator to prevent instability for degenerate meshes. - denom = torch.where( - denom < 0, - denom.clamp(max=-1e-5), - denom.clamp(min=1e-5) - ) + denom = torch.where(denom < 0, denom.clamp(max=-1e-5), denom.clamp(min=1e-5)) return (tetra_center * signed_tetra_vol).sum(dim=-2) / denom def submeshes( diff --git a/tests/test_meshes.py b/tests/test_meshes.py index 74afe33f2..6f9f8de1f 100644 --- a/tests/test_meshes.py +++ b/tests/test_meshes.py @@ -1299,13 +1299,24 @@ def test_assigned_normals(self): self.assertFalse(torch.allclose(yes_normals.verts_normals_padded(), verts)) def test_centroid(self): + meshes = init_simple_mesh() + # Check that it returns a valid value for multiple meshes with an inconsistent number + # of vertices + meshes.volume_centroid() + cube = init_cube_meshes() - self.assertClose(cube.volume_centroid(), torch.tensor([ - [0.5] * 3, - [1.5] * 3, - [2.5] * 3, - [3.5] * 3, - ])) + self.assertClose( + cube.volume_centroid(), + torch.tensor( + [ + [0.5] * 3, + [1.5] * 3, + [2.5] * 3, + [3.5] * 3, + ] + ), + ) + def test_submeshes(self): empty_mesh = Meshes([], []) # Four cubes with offsets [0, 1, 2, 3].