diff --git a/tests/test_custom_ops.py b/tests/test_custom_ops.py
index 7e8a78593..ab85e5fde 100644
--- a/tests/test_custom_ops.py
+++ b/tests/test_custom_ops.py
@@ -12,7 +12,6 @@
 from xformers.components.attention._sputnik_sparse import _csr_to_coo
 from xformers.components.attention.core import (
     _broadcast_batch,
-    _create_random_sparsity,
     _sparse_bmm,
 )
 
@@ -142,9 +141,7 @@ def test_sddmm_sputnik(device):
     prob = 0.5
     a = torch.rand(B, L, K, device=device)
     b = torch.rand(B, M, K, device=device).transpose(-2, -1)
-    mask = _create_random_sparsity(
-        torch.ones(B, L, M, dtype=torch.bool, device=device), prob
-    )
+    mask = torch.rand(B, L, M, device=device) > prob
 
     mask_csr = xformers.components.attention.core.SparseCS(mask, device)
 
@@ -173,9 +170,7 @@ def test_sddmm_csr(L, M, K, prob):
     B = 8
     a = torch.rand(B, L, K, device=device)
     b = torch.rand(B, M, K, device=device)
-    mask = _create_random_sparsity(
-        torch.ones(B, L, M, dtype=torch.bool, device=device), prob
-    )
+    mask = torch.rand(B, L, M, device=device) > prob
 
     mask_csr = xformers.components.attention.core.SparseCS(mask, device)
     row_indices = mask_csr.row_indices
@@ -193,7 +188,7 @@ def test_sddmm_csr(L, M, K, prob):
 
 
 @cuda_only
-@pytest.mark.parametrize("nnz", [0, 4, 16, 20, 36])
+@pytest.mark.parametrize("nnz", [0, 1, 4, 16, 25, 36])
 def test_sddmm_csr_per_nnz(nnz):
     device = torch.device("cuda")
     B = 8
@@ -226,13 +221,10 @@ def test_sddmm_csr_per_nnz(nnz):
 @pytest.mark.parametrize("L", [30, 17])
 def test_sddmm_coo(L, M, K, prob):
     device = torch.device("cuda")
-    # TODO add more checks for different nnz
     B = 8
     a = torch.rand(B, L, K, device=device)
     b = torch.rand(B, M, K, device=device)
-    mask = _create_random_sparsity(
-        torch.ones(B, L, M, dtype=torch.bool, device=device), prob
-    )
+    mask = torch.rand(B, L, M, device=device) > prob
 
     mask_csr = xformers.components.attention.core.SparseCS(mask, device)
     row_indices = mask_csr.row_indices
@@ -263,9 +255,7 @@ def test_sddmm_sputnik_backward(device):
     if not contiguous:
         a = a.detach().transpose(-2, -1).contiguous().transpose(-2, -1).requires_grad_()
         b = b.detach().transpose(-2, -1).contiguous().transpose(-2, -1).requires_grad_()
-    mask = _create_random_sparsity(
-        torch.ones(B, L, M, dtype=torch.bool, device=device), prob
-    )
+    mask = torch.rand(B, L, M, device=device) > prob
 
     mask_csr = xformers.components.attention.core.SparseCS(mask, device)
 
@@ -289,7 +279,8 @@ def test_sddmm_sputnik_backward(device):
 def test_sparse_softmax_sputnik(device):
     B, L = 8, 30
     prob = 0.5
-    a = _create_random_sparsity(torch.rand(B, L, L, device=device), prob)
+    mask = torch.rand(B, L, L, device=device) > prob
+    a = torch.rand(B, L, L, device=device) * mask
 
     a_csr = xformers.components.attention.core.SparseCS(a, device)
 
@@ -311,7 +302,8 @@ def test_sparse_softmax_sputnik_backward(device):
     B, L = 8, 30
     prob = 0.5
 
-    a = _create_random_sparsity(torch.rand(B, L, L, device=device), prob)
+    mask = torch.rand(B, L, L, device=device) > prob
+    a = torch.rand(B, L, L, device=device) * mask
 
     a_csr = xformers.components.attention.core.SparseCS(a, device)
 
@@ -334,7 +326,8 @@ def test_spmm_sputnik(device):
     B, L, K = 8, 30, 32
     prob = 0.5
 
-    a = _create_random_sparsity(torch.rand(B, L, L, device=device), prob)
+    mask = torch.rand(B, L, L, device=device) > prob
+    a = torch.rand(B, L, L, device=device) * mask
 
     b = torch.rand(B, L, K, device=device)
 
@@ -359,7 +352,8 @@ def test_spmm_sputnik_backward(device):
     B, M, L, K = 8, 16, 30, 32
     prob = 0.5
 
-    a = _create_random_sparsity(torch.rand(B, M, L, device=device), prob)
+    mask = torch.rand(B, M, L, device=device) > prob
+    a = torch.rand(B, M, L, device=device) * mask
 
     b = torch.rand(B, L, K, device=device)
     b.requires_grad_(True)
@@ -391,7 +385,8 @@ def test_csr_transpose():
     prob = 0.5
     device = torch.device("cuda")
 
-    a = _create_random_sparsity(torch.rand(B, L, K, device=device), prob)
+    mask = torch.rand(B, L, K, device=device) > prob
+    a = torch.rand(B, L, K, device=device) * mask
 
     a_csr = xformers.components.attention.core.SparseCS(a, device)
 
diff --git a/xformers/sparse/csr_tensor.py b/xformers/sparse/csr_tensor.py
index 5ec9846c3..240af0a19 100644
--- a/xformers/sparse/csr_tensor.py
+++ b/xformers/sparse/csr_tensor.py
@@ -182,16 +182,19 @@ def _masked_matmul(cls, a, b, mask):
         row_indices = mask.__row_indices
         row_offsets = mask.__row_offsets
         column_indices = mask.__column_indices
-        a = a.contiguous()
+        values = mask._SparseCSRTensor__values
         out = _csr_ops._sddmm.apply(
-            a,
+            a.contiguous(),
             b.transpose(-2, -1).contiguous(),
             row_indices,
             row_offsets,
             column_indices,
             mask.__transp_info,
         )
-        # TODO add bias here
+        if values.dtype == torch.bool:
+            out = torch.where(values, out, float("-inf"))
+        else:
+            out = out + values
         return cls._wrap(
             mask.shape,
             out,
diff --git a/xformers/sparse/utils.py b/xformers/sparse/utils.py
index b0ca16f2e..b0d33e8ce 100644
--- a/xformers/sparse/utils.py
+++ b/xformers/sparse/utils.py
@@ -97,12 +97,11 @@ def _dense_to_sparse(matrix, device):
 
 
 def _round_nnz(mask, divisible_by=4):
-    nonzero = torch.where(mask)
-    nnz = nonzero[0].shape[0]
-    nonzero = tuple(n[: (nnz - nnz % divisible_by)] for n in nonzero)
-    nm = torch.zeros_like(mask)
-    nm[nonzero] = True
-    return nm
+    nnz = torch.count_nonzero(mask)
+    cunz = torch.cumsum(~mask.flatten(), dim=0)
+    flip = cunz <= (-nnz) % divisible_by
+
+    return torch.logical_or(mask, flip.reshape_as(mask))
 
 
 def _dense3d_to_sparse(matrix, device):
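Note on the `_masked_matmul` change: the mask's values are now applied to the sddmm output, so a boolean mask knocks masked-out positions down to `-inf` (which a subsequent softmax turns into zeros), while a floating-point mask acts as an additive bias. A dense sketch of the intended semantics (illustrative shapes and dense stand-ins only, not the actual CSR kernel):

```python
import torch

# Hypothetical dense stand-ins for the CSR tensors in the patch.
a = torch.rand(2, 4, 8)
b = torch.rand(2, 8, 4)
out = a @ b  # what _csr_ops._sddmm computes, restricted to the mask's pattern

bool_mask = torch.rand(2, 4, 4) > 0.5
float_bias = torch.rand(2, 4, 4)

# values.dtype == torch.bool: masked-out entries become -inf,
# so a following softmax assigns them zero weight.
masked = torch.where(bool_mask, out, torch.tensor(float("-inf")))

# floating-point values: treated as an additive bias on the scores.
biased = out + float_bias
```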
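And on `_round_nnz`: the old version rounded the nonzero count *down* by truncating mask entries, silently dropping positions the caller asked for; the rewrite rounds *up*, flipping the first `(-nnz) % divisible_by` zeros to `True`, so every entry of the original mask is preserved. A small self-contained check of the arithmetic (toy mask and standalone copy of the function, not from the patch itself):

```python
import torch

def round_nnz_up(mask, divisible_by=4):
    # Mirrors the patched _round_nnz: pad nnz up to a multiple of divisible_by.
    nnz = torch.count_nonzero(mask)
    # Running count of zeros in flattened order.
    cunz = torch.cumsum(~mask.flatten(), dim=0)
    # (-nnz) % divisible_by is the number of extra nonzeros needed;
    # mark exactly that many leading zeros for flipping.
    flip = cunz <= (-nnz) % divisible_by
    return torch.logical_or(mask, flip.reshape_as(mask))

mask = torch.tensor([True, False, True, False, False, True])  # nnz == 3
rounded = round_nnz_up(mask)  # flips the first zero -> nnz == 4
assert torch.count_nonzero(rounded) == 4
assert bool(rounded[mask].all())  # original nonzeros are all kept
```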