Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

No in-place ops. #105

Closed
wants to merge 3 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 12 additions & 19 deletions cute_kernels/kernels/scattermoe/triton_implementation/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,31 +198,24 @@ def backward(ctx, grad_out):
d_gates = None
gates_flat = None
gate_fan = 1
grouped_grad_out = None
grouped_grad_out = torch.zeros_like(grad_out)
else:
# calculate gates gradient
d_gates = torch.bmm(output_expanded, grad_out.unsqueeze(2)).squeeze(-1)
gates_flat = gates.flatten()
gate_fan = gates.size(1)
# print("expanded and grouping")
grouped_grad_out = output_expanded.flatten(0, 1) # reuse expanded buffer later

if grouped_out:
grouped_grad_out = grad_out
else:
group(
A=grad_out,
sorted_expert_idxs=sorted_scattered_idxs,
out=grouped_grad_out,
coeff=gates_flat,
fan_out=gate_fan,
)
grouped_grad_out = torch.zeros_like(output_expanded.flatten(0, 1)) # reuse expanded buffer later
group(
A=grad_out,
sorted_expert_idxs=sorted_scattered_idxs,
out=grouped_grad_out,
coeff=gates_flat,
fan_out=gate_fan,
)

if grouped_in:
grouped_x = x
d_expanded_input = torch.empty(
sorted_expert_idxs.size(0), expert_weights.size(1), device=x.device, dtype=x.dtype
)
else:
grouped_x = torch.empty(sorted_scattered_idxs.size(0), x.size(1), dtype=x.dtype, device=x.device)
group(
Expand All @@ -232,10 +225,7 @@ def backward(ctx, grad_out):
fan_out=k,
)

d_expanded_input = grouped_x

d_weights = torch.zeros_like(expert_weights)

group_bwd_W(
DY=grouped_grad_out,
X=grouped_x,
Expand All @@ -244,6 +234,9 @@ def backward(ctx, grad_out):
E=expert_weights.size(0),
)

d_expanded_input = torch.empty(
sorted_expert_idxs.size(0), expert_weights.size(1), device=x.device, dtype=x.dtype
)
scatter2scatter(
X=grouped_grad_out,
W=expert_weights.permute(0, 2, 1),
Expand Down
Loading