Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix _round_nnz #1206

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Test with nnz non-divisible by 4
francois-rozet committed Jan 31, 2025
commit 36ae874a37326d7ac7e3ca9176cc6a7bafa885fe
35 changes: 15 additions & 20 deletions tests/test_custom_ops.py
Original file line number Diff line number Diff line change
@@ -12,7 +12,6 @@
from xformers.components.attention._sputnik_sparse import _csr_to_coo
from xformers.components.attention.core import (
_broadcast_batch,
-_create_random_sparsity,
_sparse_bmm,
)

@@ -142,9 +141,7 @@ def test_sddmm_sputnik(device):
prob = 0.5
a = torch.rand(B, L, K, device=device)
b = torch.rand(B, M, K, device=device).transpose(-2, -1)
-mask = _create_random_sparsity(
-torch.ones(B, L, M, dtype=torch.bool, device=device), prob
-)
+mask = torch.rand(B, L, M, device=device) > prob

mask_csr = xformers.components.attention.core.SparseCS(mask, device)

@@ -173,9 +170,7 @@ def test_sddmm_csr(L, M, K, prob):
B = 8
a = torch.rand(B, L, K, device=device)
b = torch.rand(B, M, K, device=device)
-mask = _create_random_sparsity(
-torch.ones(B, L, M, dtype=torch.bool, device=device), prob
-)
+mask = torch.rand(B, L, M, device=device) > prob

mask_csr = xformers.components.attention.core.SparseCS(mask, device)
row_indices = mask_csr.row_indices
@@ -193,7 +188,7 @@ def test_sddmm_csr(L, M, K, prob):


@cuda_only
-@pytest.mark.parametrize("nnz", [0, 4, 16, 20, 36])
+@pytest.mark.parametrize("nnz", [0, 1, 4, 16, 25, 36])
def test_sddmm_csr_per_nnz(nnz):
device = torch.device("cuda")
B = 8
@@ -226,13 +221,10 @@ def test_sddmm_csr_per_nnz(nnz):
@pytest.mark.parametrize("L", [30, 17])
def test_sddmm_coo(L, M, K, prob):
device = torch.device("cuda")
-# TODO add more checks for different nnz
B = 8
a = torch.rand(B, L, K, device=device)
b = torch.rand(B, M, K, device=device)
-mask = _create_random_sparsity(
-torch.ones(B, L, M, dtype=torch.bool, device=device), prob
-)
+mask = torch.rand(B, L, M, device=device) > prob

mask_csr = xformers.components.attention.core.SparseCS(mask, device)
row_indices = mask_csr.row_indices
@@ -263,9 +255,7 @@ def test_sddmm_sputnik_backward(device):
if not contiguous:
a = a.detach().transpose(-2, -1).contiguous().transpose(-2, -1).requires_grad_()
b = b.detach().transpose(-2, -1).contiguous().transpose(-2, -1).requires_grad_()
-mask = _create_random_sparsity(
-torch.ones(B, L, M, dtype=torch.bool, device=device), prob
-)
+mask = torch.rand(B, L, M, device=device) > prob

mask_csr = xformers.components.attention.core.SparseCS(mask, device)

@@ -289,7 +279,8 @@ def test_sddmm_sputnik_backward(device):
def test_sparse_softmax_sputnik(device):
B, L = 8, 30
prob = 0.5
-a = _create_random_sparsity(torch.rand(B, L, L, device=device), prob)
+mask = torch.rand(B, L, L, device=device) > prob
+a = torch.rand(B, L, L, device=device) * mask

a_csr = xformers.components.attention.core.SparseCS(a, device)

@@ -311,7 +302,8 @@ def test_sparse_softmax_sputnik(device):
def test_sparse_softmax_sputnik_backward(device):
B, L = 8, 30
prob = 0.5
-a = _create_random_sparsity(torch.rand(B, L, L, device=device), prob)
+mask = torch.rand(B, L, L, device=device) > prob
+a = torch.rand(B, L, L, device=device) * mask

a_csr = xformers.components.attention.core.SparseCS(a, device)

@@ -334,7 +326,8 @@ def test_spmm_sputnik(device):
B, L, K = 8, 30, 32
prob = 0.5

-a = _create_random_sparsity(torch.rand(B, L, L, device=device), prob)
+mask = torch.rand(B, L, L, device=device) > prob
+a = torch.rand(B, L, L, device=device) * mask

b = torch.rand(B, L, K, device=device)

@@ -359,7 +352,8 @@ def test_spmm_sputnik_backward(device):
B, M, L, K = 8, 16, 30, 32
prob = 0.5

-a = _create_random_sparsity(torch.rand(B, M, L, device=device), prob)
+mask = torch.rand(B, M, L, device=device) > prob
+a = torch.rand(B, M, L, device=device) * mask

b = torch.rand(B, L, K, device=device)
b.requires_grad_(True)
@@ -391,7 +385,8 @@ def test_csr_transpose():
prob = 0.5
device = torch.device("cuda")

-a = _create_random_sparsity(torch.rand(B, L, K, device=device), prob)
+mask = torch.rand(B, L, K, device=device) > prob
+a = torch.rand(B, L, K, device=device) * mask

a_csr = xformers.components.attention.core.SparseCS(a, device)