From 51ea816f3eb12f03d8e5485dc8fab93606cac636 Mon Sep 17 00:00:00 2001 From: Shawn Tan Date: Wed, 18 Dec 2024 21:00:12 +0000 Subject: [PATCH 1/2] No in-place ops. --- .../scattermoe/triton_implementation/ops.py | 31 +++++++------------ 1 file changed, 12 insertions(+), 19 deletions(-) diff --git a/cute_kernels/kernels/scattermoe/triton_implementation/ops.py b/cute_kernels/kernels/scattermoe/triton_implementation/ops.py index a9fa705b..a1d351f2 100644 --- a/cute_kernels/kernels/scattermoe/triton_implementation/ops.py +++ b/cute_kernels/kernels/scattermoe/triton_implementation/ops.py @@ -197,31 +197,24 @@ def backward(ctx, grad_out): d_gates = None gates_flat = None gate_fan = 1 - grouped_grad_out = None + grouped_grad_out = torch.zeros_like(output_expanded) else: # calculate gates gradient d_gates = torch.bmm(output_expanded, grad_out.unsqueeze(2)).squeeze(-1) gates_flat = gates.flatten() gate_fan = gates.size(1) # print("expanded and grouping") - grouped_grad_out = output_expanded.flatten(0, 1) # reuse expanded buffer later - - if grouped_out: - grouped_grad_out = grad_out - else: - group( - A=grad_out, - sorted_expert_idxs=sorted_scattered_idxs, - out=grouped_grad_out, - coeff=gates_flat, - fan_out=gate_fan, - ) + grouped_grad_out = torch.zeros_like(output_expanded.flatten(0, 1)) # reuse expanded buffer later + group( + A=grad_out, + sorted_expert_idxs=sorted_scattered_idxs, + out=grouped_grad_out, + coeff=gates_flat, + fan_out=gate_fan, + ) if grouped_in: grouped_x = x - d_expanded_input = torch.empty( - sorted_expert_idxs.size(0), expert_weights.size(1), device=x.device, dtype=x.dtype - ) else: grouped_x = torch.empty(sorted_scattered_idxs.size(0), x.size(1), dtype=x.dtype, device=x.device) group( @@ -231,10 +224,7 @@ def backward(ctx, grad_out): fan_out=k, ) - d_expanded_input = grouped_x - d_weights = torch.zeros_like(expert_weights) - group_bwd_W( DY=grouped_grad_out, X=grouped_x, @@ -243,6 +233,9 @@ def backward(ctx, grad_out): E=expert_weights.size(0), ) + d_expanded_input = torch.empty( + sorted_expert_idxs.size(0), expert_weights.size(1), device=x.device, dtype=x.dtype + ) scatter2scatter( X=grouped_grad_out, W=expert_weights.permute(0, 2, 1), From 96502f849e752d79de9497eabc70d96b122b0f62 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 18 Dec 2024 21:33:56 +0000 Subject: [PATCH 2/2] fix --- cute_kernels/kernels/scattermoe/triton_implementation/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cute_kernels/kernels/scattermoe/triton_implementation/ops.py b/cute_kernels/kernels/scattermoe/triton_implementation/ops.py index 1c7f2916..277f7ce7 100644 --- a/cute_kernels/kernels/scattermoe/triton_implementation/ops.py +++ b/cute_kernels/kernels/scattermoe/triton_implementation/ops.py @@ -198,7 +198,7 @@ def backward(ctx, grad_out): d_gates = None gates_flat = None gate_fan = 1 - grouped_grad_out = torch.zeros_like(output_expanded) + grouped_grad_out = torch.zeros_like(grad_out) else: # calculate gates gradient d_gates = torch.bmm(output_expanded, grad_out.unsqueeze(2)).squeeze(-1)