diff --git a/benchmarks/benchmark_softmax.py b/benchmarks/benchmark_softmax.py index dc01fa4f..5996d158 100644 --- a/benchmarks/benchmark_softmax.py +++ b/benchmarks/benchmark_softmax.py @@ -19,43 +19,31 @@ import trident -@util.report( - "softmax forward", - ["vec_sz"], - [256 * i for i in range(1, 21)], - {"num_bt": 32}, -) -def bench_softmax_forward(num_bt, vec_sz, backend): - inp = torch.randn(num_bt, vec_sz, device="cuda") +@util.report("softmax forward", ["x_size"], [2048 * i for i in range(1, 11)], {"y_size": 16}) +def bench_softmax_forward(y_size, x_size, backend): + input = torch.randn(y_size, x_size, device="cuda") if backend == "torch": - return triton.testing.do_bench_cudagraph(lambda: torch.softmax(inp, 1)) + return triton.testing.do_bench_cudagraph(lambda: torch.softmax(input, 1)) else: - return triton.testing.do_bench_cudagraph(lambda: trident.function.softmax(inp, 1)) + return triton.testing.do_bench_cudagraph(lambda: trident.function.softmax(input, 1)) -@util.report( - "softmax backward", - ["vec_sz"], - [256 * i for i in range(1, 21)], - {"num_bt": 32}, -) -def bench_softmax_backward(num_bt, vec_sz, backend): - inp = torch.randn(num_bt, vec_sz, device="cuda", requires_grad=True) +@util.report("softmax backward", ["x_size"], [2048 * i for i in range(1, 11)], {"y_size": 16}) +def bench_softmax_backward(y_size, x_size, backend): + input = torch.randn(y_size, x_size, device="cuda", requires_grad=True) + grad_output = torch.rand_like(input) if backend == "torch": - lyr = torch.nn.Softmax(1) + output = torch.softmax(input, 1) else: - lyr = trident.Softmax(1) + output = trident.function.softmax(input, 1) - out = lyr.forward(inp) - grad_out = torch.ones_like(inp) - - return triton.testing.do_bench_cudagraph(lambda: out.backward(grad_out, retain_graph=True)) + return triton.testing.do_bench_cudagraph(lambda: output.backward(grad_output, retain_graph=True)) def run_benchmark(mode, show_plots): if mode == "forward": bench_softmax_forward.run(print_data=True, show_plots=show_plots) else: - raise NotImplementedError("The backward isn't implemented.") + bench_softmax_backward.run(print_data=True, show_plots=show_plots) diff --git a/tests/test_softmax.py b/tests/test_softmax.py index 2ac9d191..22abf973 100644 --- a/tests/test_softmax.py +++ b/tests/test_softmax.py @@ -19,14 +19,14 @@ from tests import util -@pytest.mark.parametrize("y_size, x_size, dim", [(5, 32, 0), (2, 3000, 1)]) +@pytest.mark.parametrize("y_size, x_size, dim", [(2, 512, 0), (3, 1000, 1)]) def test_forward(y_size, x_size, dim, device): input = torch.randn(y_size, x_size, device=device) assert util.equal(torch.nn.functional.softmax(input, dim), trident.function.softmax(input, dim)) -@pytest.mark.parametrize("y_size, x_size, dim", [(300, 500, 0), (3, 7000, 1)]) +@pytest.mark.parametrize("y_size, x_size, dim", [(3, 1000, 0), (2, 512, 1)]) def test_backward(y_size, x_size, dim, device): input = torch.randn(y_size, x_size, device=device) target = torch.randn(y_size, x_size, device=device) @@ -35,7 +35,7 @@ def train(func, dim): i = input.clone() i.requires_grad = True func(i, dim).backward(target, retain_graph=True) - return [i.grad] + return (i.grad,) (x,) = train(torch.nn.functional.softmax, dim) (a,) = train(trident.function.softmax, dim) diff --git a/trident/function/function.py b/trident/function/function.py index 293ef0bd..07b7c055 100644 --- a/trident/function/function.py +++ b/trident/function/function.py @@ -35,7 +35,13 @@ def argmax(input: torch.Tensor, dim: int): return operation.Argmax.apply(input, dim) -def 
batch_norm(input, running_mean=None, running_var=None, eps=1e-05, training=False): +def batch_norm( + input: torch.Tensor, + running_mean: torch.Tensor = None, + running_var: torch.Tensor = None, + eps: float = 1e-05, + training: bool = False, +): """ Applies Batch Normalization for last certain number of dimensions. diff --git a/trident/kernel/softmax.py b/trident/kernel/softmax.py index dae032d2..75271268 100644 --- a/trident/kernel/softmax.py +++ b/trident/kernel/softmax.py @@ -20,226 +20,184 @@ def softmax_configs(): configs = [] - for num_warps in [4, 8, 16]: - for block_size in [128, 256, 512, 1024, 2048]: - config = triton.Config( - {"block_size": block_size}, - num_warps=num_warps, - ) - configs.append(config) + for x_block_size in [256, 512, 1024, 2048]: + for num_warps in [4, 8, 16]: + configs.append(triton.Config({"x_block_size": x_block_size}, num_warps=num_warps)) return configs class Softmax: @staticmethod - @triton.autotune( - configs=softmax_configs(), - key=["y_size", "x_size", "dim"], - ) + @triton.autotune(softmax_configs(), ["x_size"]) @triton.jit def forward( output_ptr: tl.tensor, input_ptr: tl.tensor, - y_size: int, - x_size: int, - dim: tl.constexpr, - block_size: tl.constexpr, + y_size: tl.int32, + x_size: tl.int32, + y_stride: tl.int32, + x_stride: tl.int32, dtype: tl.constexpr, + x_block_size: tl.constexpr, ): - offset = tl.program_id(0) - - if dim == 0: - input_block_ptr = tl.make_block_ptr( - input_ptr, - shape=(x_size, y_size), - strides=(1, x_size), - offsets=(offset, 0), - block_shape=(1, block_size), - order=(0, 1), - ) - size_along_dim = y_size - else: - input_block_ptr = tl.make_block_ptr( - input_ptr, - shape=(y_size, x_size), - strides=(x_size, 1), - offsets=(offset, 0), - block_shape=(1, block_size), - order=(1, 0), - ) - size_along_dim = x_size - - max = tl.full((1, block_size), -float("inf"), tl.float32) - sum = tl.zeros((1, block_size), tl.float32) - - for block_offset in range(0, size_along_dim, block_size): - input = tl.load(input_block_ptr, boundary_check=(1,)).to(tl.float32) - condition = tl.arange(0, block_size) + block_offset < size_along_dim - input = tl.where(condition, input, -float("inf")) - peak = tl.maximum(max, input) - peak = tl.where(condition, peak, 0) - sum = sum * tl.exp(max - peak) + tl.exp(input - peak) - max = peak - input_block_ptr = tl.advance(input_block_ptr, (0, block_size)) - - max, sum = tl.reduce((max, sum), 1, language.combine_softmax) - - if dim == 0: - input_block_ptr = tl.make_block_ptr( - input_ptr, - shape=(x_size, y_size), - strides=(1, x_size), - offsets=(offset, 0), - block_shape=(1, block_size), - order=(0, 1), - ) - output_block_ptr = tl.make_block_ptr( - output_ptr, - shape=(x_size, y_size), - strides=(1, x_size), - offsets=(offset, 0), - block_shape=(1, block_size), - order=(0, 1), - ) - else: - input_block_ptr = tl.make_block_ptr( - input_ptr, - shape=(y_size, x_size), - strides=(x_size, 1), - offsets=(offset, 0), - block_shape=(1, block_size), - order=(1, 0), - ) - output_block_ptr = tl.make_block_ptr( - output_ptr, - shape=(y_size, x_size), - strides=(x_size, 1), - offsets=(offset, 0), - block_shape=(1, block_size), - order=(1, 0), - ) - - for _ in range(0, size_along_dim, block_size): - input = tl.load(input_block_ptr, boundary_check=(1,)).to(tl.float32) - output = tl.exp(input - max) / sum - tl.store(output_block_ptr, output.to(dtype), boundary_check=(1,)) - input_block_ptr = tl.advance(input_block_ptr, (0, block_size)) - output_block_ptr = tl.advance(output_block_ptr, (0, block_size)) - - @staticmethod 
- @triton.autotune( - configs=softmax_configs(), - key=["y_size", "x_size"], - ) - @triton.jit - def backward_delta( - delta_ptr: tl.tensor, - grad_output_ptr: tl.tensor, - output_ptr: tl.tensor, - x_size: int, - y_size: int, - x_stride: int, - y_stride: int, - block_size: tl.constexpr, - ): - offset = tl.program_id(0) + y_offset = tl.program_id(0) - delta_block_ptr = tl.make_block_ptr( - delta_ptr, - shape=(y_size,), - strides=(1,), - offsets=(offset,), - block_shape=(1,), - order=(0,), - ) - grad_output_block_ptr = tl.make_block_ptr( - grad_output_ptr, + output_block_ptr = tl.make_block_ptr( + output_ptr, shape=(y_size, x_size), strides=(y_stride, x_stride), - offsets=(offset, 0), - block_shape=(1, block_size), + offsets=(y_offset, 0), + block_shape=(1, x_block_size), order=(1, 0), ) - output_block_ptr = tl.make_block_ptr( - output_ptr, + input_block_ptr = tl.make_block_ptr( + input_ptr, shape=(y_size, x_size), strides=(y_stride, x_stride), - offsets=(offset, 0), - block_shape=(1, block_size), + offsets=(y_offset, 0), + block_shape=(1, x_block_size), order=(1, 0), ) - delta = tl.zeros((1, block_size), tl.float32) + max = tl.full((1, x_block_size), -float("inf"), tl.float32) + sum = tl.zeros((1, x_block_size), tl.float32) - for _ in range(0, x_size, block_size): - output = tl.load(output_block_ptr, boundary_check=(1,), padding_option="zero").to(tl.float32) - grad_output = tl.load(grad_output_block_ptr, boundary_check=(1,), padding_option="zero").to(tl.float32) - delta += output * grad_output - output_block_ptr = tl.advance(output_block_ptr, (0, block_size)) - grad_output_block_ptr = tl.advance(grad_output_block_ptr, (0, block_size)) + for x_offset in range(0, x_size, x_block_size): + input = tl.load(input_block_ptr, boundary_check=(1,)) + condition = tl.arange(0, x_block_size) + x_offset < x_size + input = tl.where(condition, input, -float("inf")) + peak = tl.where(condition, tl.maximum(max, input), 0) + sum = sum * tl.math.fast_expf(max - peak) + tl.math.fast_expf(input - peak) + max = peak + input_block_ptr = tl.advance(input_block_ptr, (0, x_block_size)) + + max, sum = tl.reduce((max, sum), 1, language.combine_softmax) - tl.store(delta_block_ptr, tl.sum(delta, 1)) + input_block_ptr = tl.make_block_ptr( + input_ptr, + shape=(y_size, x_size), + strides=(y_stride, x_stride), + offsets=(y_offset, 0), + block_shape=(1, x_block_size), + order=(1, 0), + ) + + for x_offset in range(0, x_size, x_block_size): + input = tl.load(input_block_ptr, boundary_check=(1,)) + output = tl.math.fast_expf(input - max) / sum + tl.store(output_block_ptr, output.to(dtype), boundary_check=(1,)) + output_block_ptr = tl.advance(output_block_ptr, (0, x_block_size)) + input_block_ptr = tl.advance(input_block_ptr, (0, x_block_size)) @staticmethod - @triton.autotune( - configs=softmax_configs(), - key=["y_size", "x_size"], - ) + @triton.autotune(softmax_configs(), ["x_size"]) @triton.jit def backward( grad_input_ptr: tl.tensor, grad_output_ptr: tl.tensor, output_ptr: tl.tensor, delta_ptr: tl.tensor, - x_size: int, - y_size: int, - x_stride: int, - y_stride: int, - block_size: tl.constexpr, + y_size: tl.int32, + x_size: tl.int32, + y_stride: tl.int32, + x_stride: tl.int32, dtype: tl.constexpr, + x_block_size: tl.constexpr, ): - offset = tl.program_id(0) + y_offset = tl.program_id(0) grad_input_block_ptr = tl.make_block_ptr( grad_input_ptr, shape=(y_size, x_size), strides=(y_stride, x_stride), - offsets=(offset, 0), - block_shape=(1, block_size), + offsets=(y_offset, 0), + block_shape=(1, x_block_size), order=(1, 0), ) 
grad_output_block_ptr = tl.make_block_ptr( grad_output_ptr, shape=(y_size, x_size), strides=(y_stride, x_stride), - offsets=(offset, 0), - block_shape=(1, block_size), + offsets=(y_offset, 0), + block_shape=(1, x_block_size), order=(1, 0), ) output_block_ptr = tl.make_block_ptr( output_ptr, shape=(y_size, x_size), strides=(y_stride, x_stride), - offsets=(offset, 0), - block_shape=(1, block_size), + offsets=(y_offset, 0), + block_shape=(1, x_block_size), order=(1, 0), ) delta_block_ptr = tl.make_block_ptr( delta_ptr, shape=(y_size,), strides=(1,), - offsets=(offset,), + offsets=(y_offset,), block_shape=(1,), order=(0,), ) - grad_input = tl.zeros((1, block_size), tl.float32) - for block_offset in range(0, x_size, block_size): - output = tl.load(output_block_ptr, boundary_check=(1,), padding_option="zero").to(tl.float32) - grad_output = tl.load(grad_output_block_ptr, boundary_check=(1,), padding_option="zero").to(tl.float32) - delta = tl.load(delta_block_ptr, boundary_check=(0,), padding_option="zero").to(tl.float32) + for x_offset in range(0, x_size, x_block_size): + output = tl.load(output_block_ptr, boundary_check=(1,)) + grad_output = tl.load(grad_output_block_ptr, boundary_check=(1,)) + delta = tl.load(delta_block_ptr, boundary_check=(0,)) grad_input = output * (grad_output - delta) tl.store(grad_input_block_ptr, grad_input.to(dtype), boundary_check=(1,)) - output_block_ptr = tl.advance(output_block_ptr, (0, block_size)) - grad_output_block_ptr = tl.advance(grad_output_block_ptr, (0, block_size)) - grad_input_block_ptr = tl.advance(grad_input_block_ptr, (0, block_size)) + grad_input_block_ptr = tl.advance(grad_input_block_ptr, (0, x_block_size)) + grad_output_block_ptr = tl.advance(grad_output_block_ptr, (0, x_block_size)) + output_block_ptr = tl.advance(output_block_ptr, (0, x_block_size)) + + @staticmethod + @triton.autotune(softmax_configs(), ["x_size"]) + @triton.jit + def backward_delta( + delta_ptr: tl.tensor, + grad_output_ptr: tl.tensor, + output_ptr: tl.tensor, + y_size: tl.int32, + x_size: tl.int32, + y_stride: tl.int32, + x_stride: tl.int32, + dtype: tl.constexpr, + x_block_size: tl.constexpr, + ): + y_offset = tl.program_id(0) + + delta_block_ptr = tl.make_block_ptr( + delta_ptr, + shape=(y_size,), + strides=(1,), + offsets=(y_offset,), + block_shape=(1,), + order=(0,), + ) + grad_output_block_ptr = tl.make_block_ptr( + grad_output_ptr, + shape=(y_size, x_size), + strides=(y_stride, x_stride), + offsets=(y_offset, 0), + block_shape=(1, x_block_size), + order=(1, 0), + ) + output_block_ptr = tl.make_block_ptr( + output_ptr, + shape=(y_size, x_size), + strides=(y_stride, x_stride), + offsets=(y_offset, 0), + block_shape=(1, x_block_size), + order=(1, 0), + ) + + delta = tl.zeros((1, x_block_size), dtype) + + for _ in range(0, x_size, x_block_size): + grad_output = tl.load(grad_output_block_ptr, boundary_check=(1,), padding_option="zero") + output = tl.load(output_block_ptr, boundary_check=(1,)) + delta += grad_output * output + output_block_ptr = tl.advance(output_block_ptr, (0, x_block_size)) + grad_output_block_ptr = tl.advance(grad_output_block_ptr, (0, x_block_size)) + + delta = tl.sum(delta, 1) + tl.store(delta_block_ptr, delta) diff --git a/trident/language/combine.py b/trident/language/combine.py index 129e414c..83f2ca9a 100644 --- a/trident/language/combine.py +++ b/trident/language/combine.py @@ -29,5 +29,5 @@ def combine_welford(m2_a, mean_a, count_a, m2_b, mean_b, count_b): @triton.jit def combine_softmax(max_a: tl.tensor, sum_a: tl.tensor, max_b: tl.tensor, sum_b: 
tl.tensor): max = tl.math.max(max_a, max_b) - sum = sum_a * tl.math.exp(max_a - max) + sum_b * tl.math.exp(max_b - max) + sum = sum_a * tl.math.fast_expf(max_a - max) + sum_b * tl.math.fast_expf(max_b - max) return max, sum diff --git a/trident/module.py b/trident/module.py index 66ea394b..5b37b707 100644 --- a/trident/module.py +++ b/trident/module.py @@ -13,6 +13,7 @@ # limitations under the License. import math +from typing import Tuple import torch @@ -73,11 +74,11 @@ def __view(x): class BatchNorm1d(torch.nn.Module): def __init__( self, - num_features, - eps=1e-05, - momentum=0.1, - affine=True, - track_running_stats=True, + num_features: Tuple[int, ...], + eps: float = 1e-05, + momentum: float = 0.1, + affine: bool = True, + track_running_stats: bool = True, device=None, dtype=None, ): @@ -95,17 +96,16 @@ def __init__( """ super().__init__() + factory_kwargs = {"device": device, "dtype": dtype} self.num_features = num_features self.eps = eps self.momentum = momentum self.affine = affine self.track_running_stats = track_running_stats - self.device = device - self.dtype = dtype if affine: - self.weight = torch.nn.Parameter(torch.empty(num_features, device=device, dtype=dtype).fill_(1)) - self.bias = torch.nn.Parameter(torch.zeros(num_features, device=device, dtype=dtype)) + self.weight = torch.nn.Parameter(torch.empty(num_features, **factory_kwargs).fill_(1)) + self.bias = torch.nn.Parameter(torch.zeros(num_features, **factory_kwargs)) else: self.register_parameter("weight", None) self.register_parameter("bias", None) @@ -117,7 +117,7 @@ def __init__( self.register_parameter("running_mean", None) self.register_parameter("running_var", None) - def forward(self, input): + def forward(self, input: torch.Tensor): """ Applies Batch Normalization to an input. @@ -220,7 +220,7 @@ def __init__(self, dim: int = 1, eps: float = 1e-8): self.dim = dim self.eps = eps - def forward(self, x1, x2): + def forward(self, x1: torch.Tensor, x2: torch.Tensor): """ Applies Cosine similarity to inputs. @@ -259,7 +259,7 @@ def __init__(self, p: float = 0.5): self.p = p - def forward(self, input): + def forward(self, input: torch.Tensor): """ Applies Dropout to an input. @@ -381,10 +381,10 @@ def extra_repr(self): class GroupNorm(torch.nn.Module): def __init__( self, - num_groups, - num_channels, - eps=1e-05, - affine=True, + num_groups: int, + num_channels: int, + eps: float = 1e-05, + affine: bool = True, device=None, dtype=None, ): @@ -418,7 +418,7 @@ def __init__( self.reset_parameters() - def forward(self, input): + def forward(self, input: torch.Tensor): """ Applies Group Normalization to an input. @@ -545,11 +545,11 @@ def extra_repr(self): class InstanceNorm2d(torch.nn.Module): def __init__( self, - num_features, - eps=1e-5, - momentum=0.1, - affine=False, - track_running_stats=False, + num_features: int, + eps: float = 1e-5, + momentum: float = 0.1, + affine: bool = False, + track_running_stats: bool = False, device=None, dtype=None, ): @@ -586,7 +586,7 @@ def __init__( self.reset_parameters() - def forward(self, input): + def forward(self, input: torch.Tensor): """ Applies Instance Normalization to an input. 
@@ -646,9 +646,9 @@ def extra_repr(self): class LayerNorm(torch.nn.Module): def __init__( self, - normalized_shape, - eps=1e-05, - elementwise_affine=True, + normalized_shape: Tuple[int, ...], + eps: float = 1e-05, + elementwise_affine: bool = True, device=None, dtype=None, ): @@ -677,7 +677,7 @@ def __init__( self.reset_parameters() - def forward(self, input): + def forward(self, input: torch.Tensor): """ Applies Layer Normalization to an input. @@ -880,7 +880,7 @@ def extra_repr(self): class Mean(torch.nn.Module): - def __init__(self, dim=None): + def __init__(self, dim: int = None): """ Computes the mean along the specified dimension in an input. @@ -891,7 +891,7 @@ def __init__(self, dim=None): self.dim = dim - def forward(self, input): + def forward(self, input: torch.Tensor): """ Computes the mean along the specified dimension in an input. @@ -980,7 +980,15 @@ def extra_repr(self): class RMSNorm(torch.nn.Module): - def __init__(self, normalized_shape, p=-1.0, eps=1e-05, bias=False, device=None, dtype=None): + def __init__( + self, + normalized_shape: Tuple[int, ...], + p: float = -1.0, + eps: float = 1e-05, + bias: bool = False, + device=None, + dtype=None, + ): """ Applies Root Mean Square Layer Normalization to an input. @@ -1005,7 +1013,7 @@ def __init__(self, normalized_shape, p=-1.0, eps=1e-05, bias=False, device=None, self.reset_parameters() - def forward(self, input): + def forward(self, input: torch.Tensor): """ Applies Root Mean Square Layer Normalization to an input. @@ -1043,7 +1051,7 @@ def extra_repr(self): class ShiftGELU(torch.nn.Module): - def __init__(self, num_features: torch.int, device=None, dtype=None): + def __init__(self, num_features: int, device=None, dtype=None): """ Applies shift and the Gaussian Error Linear Units to an input. @@ -1057,7 +1065,7 @@ def __init__(self, num_features: torch.int, device=None, dtype=None): self.reset_parameters() - def forward(self, input): + def forward(self, input: torch.Tensor): """ Applies shift and the Gaussian Error Linear Units to an input. @@ -1092,7 +1100,7 @@ def __init__(self): """ super().__init__() - def forward(self, input): + def forward(self, input: torch.Tensor): """ Applies the Sigmoid Linear Unit to an input. @@ -1149,7 +1157,7 @@ def extra_repr(self): class Sum(torch.nn.Module): - def __init__(self, dim=None): + def __init__(self, dim: int = None): """ Computes the sum along the specified dimension in an input. @@ -1160,7 +1168,7 @@ def __init__(self, dim=None): self.dim = dim - def forward(self, input): + def forward(self, input: torch.Tensor): """ Computes the sum along the specified dimension in an input. @@ -1183,7 +1191,7 @@ def extra_repr(self): class Var(torch.nn.Module): - def __init__(self, dim=None, correction=1): + def __init__(self, dim: int = None, correction: int = 1): """ Computes the variance along the specified dimension in an input. @@ -1196,7 +1204,7 @@ def __init__(self, dim=None, correction=1): self.dim = dim self.correction = correction - def forward(self, input): + def forward(self, input: torch.Tensor): """ Computes the variance along the specified dimension in an input. 
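Reviewer note on the trident/kernel/softmax.py and trident/language/combine.py changes above: the forward kernel now carries a running (max, sum) pair per row and merges partial blocks with combine_softmax, using tl.math.fast_expf for the rescaling. A minimal CPU sketch of that merge rule, assuming plain float32 tensors (the helper names below are illustrative and not part of the patch):

import torch


def combine_softmax_ref(max_a, sum_a, max_b, sum_b):
    # Same merge rule as trident.language.combine_softmax: rescale both
    # partial sums to the shared maximum before adding them.
    max_ab = torch.maximum(max_a, max_b)
    return max_ab, sum_a * torch.exp(max_a - max_ab) + sum_b * torch.exp(max_b - max_ab)


def online_softmax_ref(row, x_block_size):
    # Walk the row block by block, carrying only (running max, running sum),
    # then normalize in a second pass -- the structure of Softmax.forward.
    running_max = torch.tensor(float("-inf"))
    running_sum = torch.tensor(0.0)
    for x_offset in range(0, row.numel(), x_block_size):
        block = row[x_offset : x_offset + x_block_size]
        block_max = block.max()
        block_sum = torch.exp(block - block_max).sum()
        running_max, running_sum = combine_softmax_ref(running_max, running_sum, block_max, block_sum)
    return torch.exp(row - running_max) / running_sum


row = torch.randn(1000)
assert torch.allclose(online_softmax_ref(row, 256), torch.softmax(row, 0), atol=1e-6)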
diff --git a/trident/operation/attention.py b/trident/operation/attention.py index b1effcbd..2f16f53f 100644 --- a/trident/operation/attention.py +++ b/trident/operation/attention.py @@ -132,13 +132,7 @@ def __backward( num_batches, num_heads, y_size, x_size = output.shape kernel.Softmax.backward_delta[(num_batches * num_heads * y_size,)]( - delta, - output, - grad_output, - x_size=x_size, - y_size=num_batches * num_heads * y_size, - x_stride=1, - y_stride=x_size, + delta, output, grad_output, num_batches * num_heads * y_size, x_size, x_size, 1, util.dtype(delta.dtype) ) kernel.Attention.backward[(grid[1],)]( diff --git a/trident/operation/cosine_similarity.py b/trident/operation/cosine_similarity.py index 16931cff..570ac872 100644 --- a/trident/operation/cosine_similarity.py +++ b/trident/operation/cosine_similarity.py @@ -31,7 +31,7 @@ def forward(ctx: Any, *args: Any, **kwargs: Any): return output @staticmethod - def backward(ctx, *grad_outputs): + def backward(ctx: Any, *grad_outputs: Any): grad_output = grad_outputs[0] x1, x2, denominator, numerator = ctx.saved_tensors return CosineSimilarity.__backward(grad_output, x1, x2, denominator, numerator, ctx.dim) diff --git a/trident/operation/dropout.py b/trident/operation/dropout.py index 1e1763ba..d5527d98 100644 --- a/trident/operation/dropout.py +++ b/trident/operation/dropout.py @@ -32,7 +32,7 @@ def forward(ctx: Any, *args: Any, **kwargs: Any): return output @staticmethod - def backward(ctx, *grad_outputs): + def backward(ctx: Any, *grad_outputs: Any): (grad_output,) = grad_outputs (input, output) = ctx.saved_tensors return Dropout.__backward(grad_output, input, output, ctx.p) diff --git a/trident/operation/geglu.py b/trident/operation/geglu.py index b725e058..51488433 100644 --- a/trident/operation/geglu.py +++ b/trident/operation/geglu.py @@ -32,7 +32,7 @@ def forward(ctx: Any, *args: Any, **kwargs: Any): return output @staticmethod - def backward(ctx, *grad_outputs): + def backward(ctx: Any, *grad_outputs: Any): (grad_output,) = grad_outputs input, weight, bias, state_gate = ctx.saved_tensors return GEGLU.__backward(grad_output, input, weight, bias, state_gate, ctx.use_accelerator) diff --git a/trident/operation/gelu.py b/trident/operation/gelu.py index 613b811c..f5f85f92 100644 --- a/trident/operation/gelu.py +++ b/trident/operation/gelu.py @@ -30,7 +30,7 @@ def forward(ctx: Any, *args: Any, **kwargs: Any): return GELU.__forward(input) @staticmethod - def backward(ctx, *grad_outputs): + def backward(ctx: Any, *grad_outputs: Any): (grad_output,) = grad_outputs (input,) = ctx.saved_tensors return GELU.__backward(grad_output, input) diff --git a/trident/operation/group_norm.py b/trident/operation/group_norm.py index 892514f5..cb4c6b0c 100644 --- a/trident/operation/group_norm.py +++ b/trident/operation/group_norm.py @@ -32,7 +32,7 @@ def forward(ctx: Any, *args: Any, **kwargs: Any): return output @staticmethod - def backward(ctx, *grad_outputs): + def backward(ctx: Any, *grad_outputs: Any): (grad_output,) = grad_outputs input, weight, bias, rstd, mean = ctx.saved_tensors return GroupNorm.__backward(grad_output, input, weight, bias, rstd, mean, ctx.num_groups) diff --git a/trident/operation/instance_norm.py b/trident/operation/instance_norm.py index 7da6511a..993519b1 100644 --- a/trident/operation/instance_norm.py +++ b/trident/operation/instance_norm.py @@ -35,7 +35,7 @@ def forward(ctx: Any, *args: Any, **kwargs: Any): return output @staticmethod - def backward(ctx, *grad_outputs): + def backward(ctx: Any, *grad_outputs: Any): 
input, running_mean, running_var, weight, mean, var, weight, bias = ctx.saved_tensors (grad_output,) = grad_outputs return InstanceNorm.__backward( diff --git a/trident/operation/layer_norm.py b/trident/operation/layer_norm.py index cf521ade..f8c59104 100644 --- a/trident/operation/layer_norm.py +++ b/trident/operation/layer_norm.py @@ -33,7 +33,7 @@ def forward(ctx: Any, *args: Any, **kwargs: Any): return output @staticmethod - def backward(ctx, *grad_outputs): + def backward(ctx: Any, *grad_outputs: Any): (grad_output,) = grad_outputs input, weight, bias, rstd, mean = ctx.saved_tensors return LayerNorm.__backward(grad_output, input, ctx.normalized_shape, weight, bias, rstd, mean) diff --git a/trident/operation/leaky_relu.py b/trident/operation/leaky_relu.py index 56883186..41a0e550 100644 --- a/trident/operation/leaky_relu.py +++ b/trident/operation/leaky_relu.py @@ -31,7 +31,7 @@ def forward(ctx: Any, *args: Any, **kwargs: Any): return LeakyReLU.__forward(input, negative_slope) @staticmethod - def backward(ctx, *grad_outputs): + def backward(ctx: Any, *grad_outputs: Any): grad_output = grad_outputs[0] (input,) = ctx.saved_tensors return LeakyReLU.__backward(grad_output, input, ctx.negative_slope) diff --git a/trident/operation/linear.py b/trident/operation/linear.py index 1dd3640e..b6457d32 100644 --- a/trident/operation/linear.py +++ b/trident/operation/linear.py @@ -32,7 +32,7 @@ def forward(ctx: Any, *args: Any, **kwargs: Any): return output @staticmethod - def backward(ctx, *grad_outputs): + def backward(ctx: Any, *grad_outputs: Any): (grad_output,) = grad_outputs input, weight, bias, output = ctx.saved_tensors return Linear.__backward(grad_output, output, input, weight, bias, ctx.use_accelerator) diff --git a/trident/operation/max.py b/trident/operation/max.py index 2a017574..adce1b7e 100644 --- a/trident/operation/max.py +++ b/trident/operation/max.py @@ -32,7 +32,7 @@ def forward(ctx: Any, *args: Any, **kwargs: Any): return output, argmax @staticmethod - def backward(ctx, *grad_outputs): + def backward(ctx: Any, *grad_outputs: Any): grad_output, grad_argmax = grad_outputs input, output, argmax = ctx.saved_tensors return Max.__backward(grad_output, input, argmax, ctx.dim) diff --git a/trident/operation/mean.py b/trident/operation/mean.py index 0a23d677..dd7e0df4 100644 --- a/trident/operation/mean.py +++ b/trident/operation/mean.py @@ -31,7 +31,7 @@ def forward(ctx: Any, *args: Any, **kwargs: Any): return Mean.__forward(input, dim) @staticmethod - def backward(ctx, *grad_outputs): + def backward(ctx: Any, *grad_outputs: Any): (input,) = ctx.saved_tensors (grad_output,) = grad_outputs return Mean.__backward(grad_output, input, ctx.dim) diff --git a/trident/operation/prelu.py b/trident/operation/prelu.py index 692801ce..60225757 100644 --- a/trident/operation/prelu.py +++ b/trident/operation/prelu.py @@ -50,7 +50,7 @@ def grid(meta): return output @staticmethod - def backward(ctx, *grad_outputs): + def backward(ctx: Any, *grad_outputs: Any): return PReLU.__backward(grad_outputs[0], *ctx.saved_tensors) @staticmethod diff --git a/trident/operation/relu.py b/trident/operation/relu.py index 465c78f2..774ab9e2 100644 --- a/trident/operation/relu.py +++ b/trident/operation/relu.py @@ -30,7 +30,7 @@ def forward(ctx: Any, *args: Any, **kwargs: Any): return ReLU.__forward(input) @staticmethod - def backward(ctx, *grad_outputs): + def backward(ctx: Any, *grad_outputs: Any): grad_output = grad_outputs[0] (input,) = ctx.saved_tensors return ReLU.__backward(grad_output, input) diff --git 
a/trident/operation/rms_norm.py b/trident/operation/rms_norm.py index 9a2070c0..1d0d2f7c 100644 --- a/trident/operation/rms_norm.py +++ b/trident/operation/rms_norm.py @@ -33,7 +33,7 @@ def forward(ctx: Any, *args: Any, **kwargs: Any): return output @staticmethod - def backward(ctx, *grad_outputs): + def backward(ctx: Any, *grad_outputs: Any): (grad_output,) = grad_outputs input, rms, weight, bias = ctx.saved_tensors return RMSNorm.__backward(grad_output, input, ctx.p, rms, weight, bias, ctx.eps) diff --git a/trident/operation/silu.py b/trident/operation/silu.py index f92101b7..db117526 100644 --- a/trident/operation/silu.py +++ b/trident/operation/silu.py @@ -29,7 +29,7 @@ def forward(ctx: Any, *args: Any, **kwargs: Any): return SiLU.__forward(input) @staticmethod - def backward(ctx, *grad_outputs): + def backward(ctx: Any, *grad_outputs: Any): (grad_output,) = grad_outputs (input,) = ctx.saved_tensors return SiLU.__backward(grad_output, input) diff --git a/trident/operation/softmax.py b/trident/operation/softmax.py index 2c8d3f56..2f2a2f5f 100644 --- a/trident/operation/softmax.py +++ b/trident/operation/softmax.py @@ -15,6 +15,7 @@ from typing import Any import torch +import triton from trident import kernel, util @@ -31,62 +32,65 @@ def forward(ctx: Any, *args: Any, **kwargs: Any): return output @staticmethod - def backward(ctx, *grad_outputs): + def backward(ctx: Any, *grad_outputs: Any): (grad_output,) = grad_outputs (output,) = ctx.saved_tensors return Softmax.__backward(grad_output, output, ctx.dim) @staticmethod def __forward(input: torch.Tensor, dim: int): - y_size, x_size = input.shape + y_size, x_size, y_stride, x_stride = util.size_and_stride(input, dim) output = torch.empty_like(input) def grid(meta): - return (x_size if dim == 0 else y_size,) + return (y_size,) kernel.Softmax.forward[grid]( output, input, y_size, x_size, - dim, - dtype=util.dtype(input.dtype), + y_stride, + x_stride, + util.dtype(output.dtype), ) return output @staticmethod - def __backward(grad_output: torch.Tensor, output: torch.Tensor, dim: int): - factory_kwargs = {"device": grad_output.device} + def __backward(grad_output: torch.Tensor, output: torch.Tensor, dim: torch.int32): + factory_kwargs = {"device": output.device, "dtype": output.dtype} y_size, x_size, y_stride, x_stride = util.size_and_stride(output, dim) + delta = torch.empty(x_size, **factory_kwargs) + grad_input = torch.empty_like(output) def grid(meta): return (y_size,) - delta = torch.empty(x_size, **factory_kwargs) - kernel.Softmax.backward_delta[grid]( delta, grad_output, output, - x_size, y_size, - x_stride, + x_size, y_stride, + x_stride, + util.dtype(delta.dtype), ) - grad_input = torch.empty_like(output) + def grid(meta): + return (y_size,) kernel.Softmax.backward[grid]( grad_input, grad_output, output, delta, - x_size, y_size, - x_stride, + x_size, y_stride, - dtype=util.dtype(output.dtype), + x_stride, + util.dtype(output.dtype), ) return grad_input, None diff --git a/trident/operation/sum.py b/trident/operation/sum.py index 0bc391d4..ffa72815 100644 --- a/trident/operation/sum.py +++ b/trident/operation/sum.py @@ -30,7 +30,7 @@ def forward(ctx: Any, *args: Any, **kwargs: Any): return Sum.__forward(input, dim) @staticmethod - def backward(ctx, *grad_outputs): + def backward(ctx: Any, *grad_outputs: Any): (input,) = ctx.saved_tensors (grad_output,) = grad_outputs return Sum.__backward(grad_output, input, ctx.dim) diff --git a/trident/operation/var.py b/trident/operation/var.py index 00ea2ad5..da8676ba 100644 --- 
a/trident/operation/var.py +++ b/trident/operation/var.py @@ -32,7 +32,7 @@ def forward(ctx: Any, *args: Any, **kwargs: Any): return Var.__forward(input, dim, correction) @staticmethod - def backward(ctx, *grad_outputs): + def backward(ctx: Any, *grad_outputs: Any): (input,) = ctx.saved_tensors (grad_output,) = grad_outputs return Var.__backward(grad_output, input, ctx.dim, ctx.correction) diff --git a/trident/operation/var_mean.py b/trident/operation/var_mean.py index a581176a..18dcf654 100644 --- a/trident/operation/var_mean.py +++ b/trident/operation/var_mean.py @@ -33,7 +33,7 @@ def forward(ctx: Any, *args: Any, **kwargs: Any): return output, mean @staticmethod - def backward(ctx, *grad_outputs): + def backward(ctx: Any, *grad_outputs: Any): (input, mean) = ctx.saved_tensors (grad_output, _) = grad_outputs return VarMean.__backward(grad_output, input, mean, ctx.dim, ctx.correction)
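Reviewer note on the softmax backward path: backward_delta reduces grad_output * output per row into delta, and backward then computes output * (grad_output - delta). A small host-side sketch of that identity checked against autograd, assuming 2D float32 inputs (names below are illustrative only, not part of the patch):

import torch


def softmax_backward_ref(grad_output, output):
    # delta is the per-row dot product reduced by Softmax.backward_delta;
    # Softmax.backward then applies output * (grad_output - delta).
    delta = (grad_output * output).sum(dim=1, keepdim=True)
    return output * (grad_output - delta)


y_size, x_size = 3, 1000
input = torch.randn(y_size, x_size, requires_grad=True)
output = torch.softmax(input, 1)
grad_output = torch.rand_like(output)
output.backward(grad_output)
assert torch.allclose(softmax_backward_ref(grad_output, output.detach()), input.grad, atol=1e-6)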