@@ -66,21 +66,14 @@ def generate_2_to_4_sparse_tensor(shape, dtype=torch.float32, device='cpu'):
6666 raise ValueError ("Last dimension must be divisible by 4 for 2:4 sparsity." )
6767
6868 full_tensor = torch .randn (shape , dtype = dtype , device = device )
69- mask = torch .zeros_like (full_tensor , dtype = torch .bool )
70-
7169 group_count = shape [- 1 ] // 4
7270 group_shape = shape [:- 1 ] + (group_count , 4 )
7371
74- reshaped = full_tensor .view (* group_shape )
75-
76- for idx in range (reshaped .numel () // 4 ):
77- flat_idx = torch .randint (0 , 4 , (2 ,), dtype = torch .int64 )
78- while flat_idx [0 ] == flat_idx [1 ]:
79- flat_idx [1 ] = torch .randint (0 , 4 , (1 ,), dtype = torch .int64 )
80- i = idx // group_count
81- j = idx % group_count
82- mask .view (* group_shape )[i , j , flat_idx [0 ]] = True
83- mask .view (* group_shape )[i , j , flat_idx [1 ]] = True
72+ rand_vals = torch .rand (group_shape , device = device )
73+ topk_indices = rand_vals .topk (k = 2 , dim = - 1 ).indices
74+ mask = torch .zeros (group_shape , dtype = torch .bool , device = device )
75+ mask .scatter_ (- 1 , topk_indices , True )
76+ mask = mask .view (shape )
8477
8578 sparse_tensor = full_tensor * mask
8679 return sparse_tensor
0 commit comments