This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

Commit: Refactor Softmax

daemyung committed Sep 8, 2023
1 parent 5941c4e commit 33251bc

Showing 26 changed files with 227 additions and 269 deletions.
38 changes: 13 additions & 25 deletions benchmarks/benchmark_softmax.py
@@ -19,43 +19,31 @@
 import trident
 
 
-@util.report(
-    "softmax forward",
-    ["vec_sz"],
-    [256 * i for i in range(1, 21)],
-    {"num_bt": 32},
-)
-def bench_softmax_forward(num_bt, vec_sz, backend):
-    inp = torch.randn(num_bt, vec_sz, device="cuda")
+@util.report("softmax forward", ["x_size"], [2048 * i for i in range(1, 11)], {"y_size": 16})
+def bench_softmax_forward(y_size, x_size, backend):
+    input = torch.randn(y_size, x_size, device="cuda")
 
     if backend == "torch":
-        return triton.testing.do_bench_cudagraph(lambda: torch.softmax(inp, 1))
+        return triton.testing.do_bench_cudagraph(lambda: torch.softmax(input, 1))
     else:
-        return triton.testing.do_bench_cudagraph(lambda: trident.function.softmax(inp, 1))
+        return triton.testing.do_bench_cudagraph(lambda: trident.function.softmax(input, 1))
 
 
-@util.report(
-    "softmax backward",
-    ["vec_sz"],
-    [256 * i for i in range(1, 21)],
-    {"num_bt": 32},
-)
-def bench_softmax_backward(num_bt, vec_sz, backend):
-    inp = torch.randn(num_bt, vec_sz, device="cuda", requires_grad=True)
+@util.report("softmax backward", ["x_size"], [2048 * i for i in range(1, 11)], {"y_size": 16})
+def bench_softmax_backward(y_size, x_size, backend):
+    input = torch.randn(y_size, x_size, device="cuda", requires_grad=True)
+    grad_output = torch.rand_like(input)
 
     if backend == "torch":
-        lyr = torch.nn.Softmax(1)
+        output = torch.softmax(input, 1)
     else:
-        lyr = trident.Softmax(1)
+        output = trident.function.softmax(input, 1)
 
-    out = lyr.forward(inp)
-    grad_out = torch.ones_like(inp)
-
-    return triton.testing.do_bench_cudagraph(lambda: out.backward(grad_out, retain_graph=True))
+    return triton.testing.do_bench_cudagraph(lambda: output.backward(grad_output, retain_graph=True))
 
 
 def run_benchmark(mode, show_plots):
     if mode == "forward":
         bench_softmax_forward.run(print_data=True, show_plots=show_plots)
     else:
-        raise NotImplementedError("The backward isn't implemented.")
+        bench_softmax_backward.run(print_data=True, show_plots=show_plots)
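
Outside the @util.report harness, the refactored backward benchmark boils down to the pattern below: run the forward pass once, then time only the replayed backward call under a CUDA graph. A minimal sketch, assuming a CUDA device, a Triton version that provides triton.testing.do_bench_cudagraph, and the new default shape (y_size=16, x_size=2048):

import torch
import triton

y_size, x_size = 16, 2048
input = torch.randn(y_size, x_size, device="cuda", requires_grad=True)
grad_output = torch.rand_like(input)

# The forward pass runs once outside the timed region; retain_graph=True lets
# the same autograd graph be replayed for every timed backward call.
output = torch.softmax(input, 1)
ms = triton.testing.do_bench_cudagraph(lambda: output.backward(grad_output, retain_graph=True))
print(f"softmax backward: {ms:.3f} ms")
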
6 changes: 3 additions & 3 deletions tests/test_softmax.py
@@ -19,14 +19,14 @@
 from tests import util
 
 
-@pytest.mark.parametrize("y_size, x_size, dim", [(5, 32, 0), (2, 3000, 1)])
+@pytest.mark.parametrize("y_size, x_size, dim", [(2, 512, 0), (3, 1000, 1)])
 def test_forward(y_size, x_size, dim, device):
     input = torch.randn(y_size, x_size, device=device)
 
     assert util.equal(torch.nn.functional.softmax(input, dim), trident.function.softmax(input, dim))
 
 
-@pytest.mark.parametrize("y_size, x_size, dim", [(300, 500, 0), (3, 7000, 1)])
+@pytest.mark.parametrize("y_size, x_size, dim", [(3, 1000, 0), (2, 512, 1)])
 def test_backward(y_size, x_size, dim, device):
     input = torch.randn(y_size, x_size, device=device)
     target = torch.randn(y_size, x_size, device=device)
@@ -35,7 +35,7 @@ def train(func, dim):
         i = input.clone()
         i.requires_grad = True
         func(i, dim).backward(target, retain_graph=True)
-        return [i.grad]
+        return (i.grad,)
 
     (x,) = train(torch.nn.functional.softmax, dim)
     (a,) = train(trident.function.softmax, dim)
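
Both backends should produce the same input gradient here. For reference, the softmax gradient has a closed form: with y = softmax(x, dim) and upstream gradient g, the input gradient is y * (g - (g * y).sum(dim, keepdim=True)). A small sanity check of that identity against torch autograd (a sketch with assumed shapes, not part of the test suite):

import torch

y_size, x_size, dim = 3, 1000, 1
input = torch.randn(y_size, x_size, requires_grad=True)
target = torch.randn(y_size, x_size)

y = torch.softmax(input, dim)
y.backward(target)

# Analytic vector-Jacobian product of softmax along `dim`.
yd = y.detach()
manual_grad = yd * (target - (target * yd).sum(dim, keepdim=True))
assert torch.allclose(input.grad, manual_grad, atol=1e-6)
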
8 changes: 7 additions & 1 deletion trident/function/function.py
@@ -35,7 +35,13 @@ def argmax(input: torch.Tensor, dim: int):
     return operation.Argmax.apply(input, dim)
 
 
-def batch_norm(input, running_mean=None, running_var=None, eps=1e-05, training=False):
+def batch_norm(
+    input: torch.Tensor,
+    running_mean: torch.Tensor = None,
+    running_var: torch.Tensor = None,
+    eps: float = 1e-05,
+    training: bool = False,
+):
     """
     Applies Batch Normalization for last certain number of dimensions.
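
For orientation, a hedged usage sketch of the re-annotated batch_norm. It assumes the function is exposed as trident.function.batch_norm (by analogy with trident.function.softmax above) and that it follows the usual convention of an (N, C) input with per-channel running statistics, as in torch.nn.functional.batch_norm; neither assumption is confirmed by this diff.

import torch
import trident

input = torch.randn(8, 32, device="cuda")      # assumed (N, C) layout
running_mean = torch.zeros(32, device="cuda")  # assumed per-channel statistics
running_var = torch.ones(32, device="cuda")

# Inference-mode call spelling out the defaults made explicit by the new signature.
output = trident.function.batch_norm(input, running_mean, running_var, eps=1e-05, training=False)
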
