Fix _round_nnz #1206

Open · wants to merge 2 commits into main
35 changes: 15 additions & 20 deletions tests/test_custom_ops.py
@@ -12,7 +12,6 @@
 from xformers.components.attention._sputnik_sparse import _csr_to_coo
 from xformers.components.attention.core import (
     _broadcast_batch,
-    _create_random_sparsity,
     _sparse_bmm,
 )

@@ -142,9 +141,7 @@ def test_sddmm_sputnik(device):
     prob = 0.5
     a = torch.rand(B, L, K, device=device)
     b = torch.rand(B, M, K, device=device).transpose(-2, -1)
-    mask = _create_random_sparsity(
-        torch.ones(B, L, M, dtype=torch.bool, device=device), prob
-    )
+    mask = torch.rand(B, L, M, device=device) > prob

     mask_csr = xformers.components.attention.core.SparseCS(mask, device)

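Throughout this test file the `_create_random_sparsity` helper is replaced by an inline Bernoulli mask, which is why its import is dropped in the first hunk above. A minimal sketch of the new idiom, with illustrative shapes:

import torch

B, L, M = 8, 30, 16  # illustrative sizes, not the exact test parameters
prob = 0.5           # expected fraction of masked-out (False) entries

# Each position is kept independently with probability 1 - prob,
# yielding a boolean mask with on average prob * B * L * M zeros.
mask = torch.rand(B, L, M) > prob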
@@ -173,9 +170,7 @@ def test_sddmm_csr(L, M, K, prob):
     B = 8
     a = torch.rand(B, L, K, device=device)
     b = torch.rand(B, M, K, device=device)
-    mask = _create_random_sparsity(
-        torch.ones(B, L, M, dtype=torch.bool, device=device), prob
-    )
+    mask = torch.rand(B, L, M, device=device) > prob

     mask_csr = xformers.components.attention.core.SparseCS(mask, device)
     row_indices = mask_csr.row_indices
@@ -193,7 +188,7 @@ def test_sddmm_csr(L, M, K, prob):


 @cuda_only
-@pytest.mark.parametrize("nnz", [0, 4, 16, 20, 36])
+@pytest.mark.parametrize("nnz", [0, 1, 4, 16, 25, 36])
 def test_sddmm_csr_per_nnz(nnz):
     device = torch.device("cuda")
     B = 8
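The widened parameter list now includes nnz values that are not multiples of 4 (1 and 25), which is exactly the case the rewritten `_round_nnz` below has to handle by padding the pattern up. A quick check of the padding arithmetic, assuming divisible_by=4:

# extra nonzeros needed to round each nnz up to a multiple of 4
for nnz in [0, 1, 4, 16, 25, 36]:
    pad = (-nnz) % 4
    print(nnz, "->", nnz + pad)  # 0->0, 1->4, 4->4, 16->16, 25->28, 36->36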
@@ -226,13 +221,10 @@ def test_sddmm_csr_per_nnz(nnz):
 @pytest.mark.parametrize("L", [30, 17])
 def test_sddmm_coo(L, M, K, prob):
     device = torch.device("cuda")
-    # TODO add more checks for different nnz
     B = 8
     a = torch.rand(B, L, K, device=device)
     b = torch.rand(B, M, K, device=device)
-    mask = _create_random_sparsity(
-        torch.ones(B, L, M, dtype=torch.bool, device=device), prob
-    )
+    mask = torch.rand(B, L, M, device=device) > prob

     mask_csr = xformers.components.attention.core.SparseCS(mask, device)
     row_indices = mask_csr.row_indices
@@ -263,9 +255,7 @@ def test_sddmm_sputnik_backward(device):
     if not contiguous:
         a = a.detach().transpose(-2, -1).contiguous().transpose(-2, -1).requires_grad_()
         b = b.detach().transpose(-2, -1).contiguous().transpose(-2, -1).requires_grad_()
-    mask = _create_random_sparsity(
-        torch.ones(B, L, M, dtype=torch.bool, device=device), prob
-    )
+    mask = torch.rand(B, L, M, device=device) > prob

     mask_csr = xformers.components.attention.core.SparseCS(mask, device)

@@ -289,7 +279,8 @@
 def test_sparse_softmax_sputnik(device):
     B, L = 8, 30
     prob = 0.5
-    a = _create_random_sparsity(torch.rand(B, L, L, device=device), prob)
+    mask = torch.rand(B, L, L, device=device) > prob
+    a = torch.rand(B, L, L, device=device) * mask

     a_csr = xformers.components.attention.core.SparseCS(a, device)

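Where a test needs sparse values rather than just a pattern, the same Bernoulli mask is applied multiplicatively, so the dense tensor has exact zeros off the pattern. A minimal sketch with illustrative sizes:

import torch

B, L = 8, 30
prob = 0.5

mask = torch.rand(B, L, L) > prob  # boolean sparsity pattern
a = torch.rand(B, L, L) * mask     # bool promotes to float; masked entries become exactly 0.0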
@@ -311,7 +302,8 @@ def test_sparse_softmax_sputnik(device):
 def test_sparse_softmax_sputnik_backward(device):
     B, L = 8, 30
     prob = 0.5
-    a = _create_random_sparsity(torch.rand(B, L, L, device=device), prob)
+    mask = torch.rand(B, L, L, device=device) > prob
+    a = torch.rand(B, L, L, device=device) * mask

     a_csr = xformers.components.attention.core.SparseCS(a, device)

@@ -334,7 +326,8 @@ def test_spmm_sputnik(device):
     B, L, K = 8, 30, 32
     prob = 0.5

-    a = _create_random_sparsity(torch.rand(B, L, L, device=device), prob)
+    mask = torch.rand(B, L, L, device=device) > prob
+    a = torch.rand(B, L, L, device=device) * mask

     b = torch.rand(B, L, K, device=device)

@@ -359,7 +352,8 @@ def test_spmm_sputnik_backward(device):
     B, M, L, K = 8, 16, 30, 32
     prob = 0.5

-    a = _create_random_sparsity(torch.rand(B, M, L, device=device), prob)
+    mask = torch.rand(B, M, L, device=device) > prob
+    a = torch.rand(B, M, L, device=device) * mask

     b = torch.rand(B, L, K, device=device)
     b.requires_grad_(True)
@@ -391,7 +385,8 @@ def test_csr_transpose():
     prob = 0.5
     device = torch.device("cuda")

-    a = _create_random_sparsity(torch.rand(B, L, K, device=device), prob)
+    mask = torch.rand(B, L, K, device=device) > prob
+    a = torch.rand(B, L, K, device=device) * mask

     a_csr = xformers.components.attention.core.SparseCS(a, device)

9 changes: 6 additions & 3 deletions xformers/sparse/csr_tensor.py
@@ -182,16 +182,19 @@ def _masked_matmul(cls, a, b, mask):
         row_indices = mask.__row_indices
         row_offsets = mask.__row_offsets
         column_indices = mask.__column_indices
-        a = a.contiguous()
+        values = mask._SparseCSRTensor__values
         out = _csr_ops._sddmm.apply(
-            a,
+            a.contiguous(),
             b.transpose(-2, -1).contiguous(),
             row_indices,
             row_offsets,
             column_indices,
             mask.__transp_info,
         )
-        # TODO add bias here
+        if values.dtype == torch.bool:
+            out = torch.where(values, out, float("-inf"))
+        else:
+            out = out + values
         return cls._wrap(
             mask.shape,
             out,
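The new branch resolves the old TODO: `_masked_matmul` now consumes the mask's stored values, treating a boolean mask as a hard mask (masked positions become -inf, so a following softmax sends them to zero) and any other dtype as an additive bias. A dense-math sketch of the same semantics, with hypothetical shapes:

import torch

a = torch.rand(2, 4, 8)
b = torch.rand(2, 8, 4)
scores = a @ b  # dense counterpart of the sparse SDDMM output

# Boolean values: keep scores where True, write -inf elsewhere.
bool_values = torch.rand(2, 4, 4) > 0.5
masked = torch.where(bool_values, scores, torch.full_like(scores, float("-inf")))

# Any other dtype: the values act as an additive bias on the scores.
bias_values = torch.rand(2, 4, 4)
biased = scores + bias_values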
11 changes: 5 additions & 6 deletions xformers/sparse/utils.py
@@ -97,12 +97,11 @@ def _dense_to_sparse(matrix, device):


 def _round_nnz(mask, divisible_by=4):
-    nonzero = torch.where(mask)
-    nnz = nonzero[0].shape[0]
-    nonzero = tuple(n[: (nnz - nnz % divisible_by)] for n in nonzero)
-    nm = torch.zeros_like(mask)
-    nm[nonzero] = True
-    return nm
+    nnz = torch.count_nonzero(mask)
+    cunz = torch.cumsum(~mask.flatten(), dim=0)
+    flip = cunz <= (-nnz) % divisible_by
+
+    return torch.logical_or(mask, flip.reshape_as(mask))


 def _dense3d_to_sparse(matrix, device):
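The old `_round_nnz` truncated the nonzero list down to a multiple of divisible_by, silently dropping entries from the pattern and returning an empty mask whenever nnz < divisible_by. The rewrite pads up instead: `(-nnz) % divisible_by` is the number of extra nonzeros needed, and the cumulative sum over the flattened inverted mask picks out exactly that many leading zero positions to flip to True. A worked example on a small mask, assuming divisible_by=4:

import torch

mask = torch.tensor([True, False, True, False, False, True])  # nnz = 3
divisible_by = 4

nnz = torch.count_nonzero(mask)              # 3
need = (-nnz) % divisible_by                 # 1 extra nonzero to reach 4
cunz = torch.cumsum(~mask.flatten(), dim=0)  # [0, 1, 1, 2, 3, 3]
flip = cunz <= need                          # [T, T, T, F, F, F]
out = torch.logical_or(mask, flip.reshape_as(mask))
# out = [T, T, T, F, F, T]: nnz is now 4, a multiple of divisible_by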