This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

Support Torch 1.13
Jaehyun An committed Sep 7, 2023
1 parent f174963 commit 04852ea
Showing 29 changed files with 218 additions and 185 deletions.
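Apart from relaxing the torch pin in requirements.txt, the changes follow a single pattern: every trident operation is moved off the torch.autograd.Function style introduced in PyTorch 2.0, where forward takes no ctx and a separate setup_context hook stores state, back to the classic style that PyTorch 1.13 understands, where ctx is the first argument of forward. A minimal sketch of the two styles, using a made-up Double operation rather than any code from this repository:

import torch

class DoubleTorch2(torch.autograd.Function):
    # PyTorch >= 2.0 style: forward has no ctx; setup_context saves state for backward.
    @staticmethod
    def forward(input):
        return input * 2

    @staticmethod
    def setup_context(ctx, inputs, output):
        pass  # nothing to save for this toy op

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * 2

class DoubleTorch113(torch.autograd.Function):
    # PyTorch 1.13-compatible style, as used throughout this commit: ctx is passed to forward.
    @staticmethod
    def forward(ctx, *args):
        (input,) = args
        return input * 2

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * 2

Both classes are invoked the same way, e.g. DoubleTorch113.apply(torch.ones(3, requires_grad=True)); only where ctx bookkeeping happens differs.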
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,5 +1,5 @@
 pytest
 pytest-xdist
-torch>=2.0.0
+torch
 matplotlib
 pandas
25 changes: 9 additions & 16 deletions trident/function/function.py
@@ -64,8 +64,7 @@ def cosine_similarity(x1: torch.Tensor, x2: torch.Tensor, dim: int = 1, eps: flo
     See cosine similarity for detail.
     """
-    output, _, _ = operation.CosineSimilarity.apply(x1, x2, dim, eps)
-    return output
+    return operation.CosineSimilarity.apply(x1, x2, dim, eps)


 def dropout(input, p=0.5, training=True):
@@ -87,11 +86,10 @@ def geglu(input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor = None,
     See GEGLU for details.
     """
     if input.dim() == 2:
-        output, _ = operation.GEGLU.apply(input.view(1, *input.shape), weight, bias, use_accelerator)
+        output = operation.GEGLU.apply(input.view(1, *input.shape), weight, bias, use_accelerator)
         return output.view(output.shape[1:3])
     else:
-        output, _ = operation.GEGLU.apply(input, weight, bias, use_accelerator)
-        return output
+        return operation.GEGLU.apply(input, weight, bias, use_accelerator)


 def gelu(input: torch.Tensor):
@@ -109,10 +107,9 @@ def group_norm(input, num_groups, weight=None, bias=None, eps=1e-05):
     See GroupNorm for details.
     """
-    output, _, _ = operation.GroupNorm.apply(
+    return operation.GroupNorm.apply(
         input.view(input.shape[0], input.shape[1], -1), num_groups, weight, bias, eps
-    )
-    return output.view(input.shape)
+    ).view(input.shape)


 def instance_norm(
@@ -130,7 +127,7 @@ def instance_norm(
     See InstanceNorm2d for details.
     """
-    output, _, _ = operation.InstanceNorm.apply(
+    return operation.InstanceNorm.apply(
         input,
         running_mean,
         running_var,
@@ -140,7 +137,6 @@ def instance_norm(
         momentum,
         eps,
     )
-    return output


 def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
@@ -149,8 +145,7 @@ def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
     See LayerNorm for details.
    """
-    output, _, _ = operation.LayerNorm.apply(input, normalized_shape, weight, bias, eps)
-    return output
+    return operation.LayerNorm.apply(input, normalized_shape, weight, bias, eps)


 def leaky_relu(input: torch.Tensor, negative_slope: float = 0.01):
@@ -222,8 +217,7 @@ def rms_norm(input: torch.Tensor, p: float, weight: torch.Tensor, bias: torch.Te
     See RMSNorm for details.
     """
-    output, _ = operation.RMSNorm.apply(input.view(-1, input.shape[-1]), p, weight, bias, eps)
-    return output.view(input.shape)
+    return operation.RMSNorm.apply(input.view(-1, input.shape[-1]), p, weight, bias, eps).view(input.shape)


 def shift_gelu(input: torch.Tensor, bias: torch.Tensor):
@@ -232,8 +226,7 @@ def shift_gelu(input: torch.Tensor, bias: torch.Tensor):
     See ShiftGELU for details.
     """
-    output, _ = operation.ShiftGELU.apply(input.view(-1, input.shape[-1]), bias)
-    return output.view(input.shape)
+    return operation.ShiftGELU.apply(input.view(-1, input.shape[-1]), bias).view(input.shape)


 def silu(input):
2 changes: 1 addition & 1 deletion trident/module.py
@@ -134,7 +134,7 @@ def forward(self, input):
             self.running_mean = input.mean(axis=0) * self.momentum + self.running_mean * (1 - self.momentum)
             self.running_var = input.var(axis=0) * self.momentum + self.running_var * (1 - self.momentum)

-        return operation.BatchNorm.apply(input, self.weight, self.bias, self.eps)
+        return operation.BatchNorm.apply(input, self.weight, self.bias, self.eps, None, None)

     def extra_repr(self):
         """
11 changes: 6 additions & 5 deletions trident/operation/adaptive_avg_pool2d.py
@@ -12,16 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from typing import Any
+
 import torch

 from trident import kernel


 class AdaptiveAvgPool2d(torch.autograd.Function):
     @staticmethod
-    def forward(*args, **kwargs):
+    def forward(ctx: Any, *args: Any, **kwargs: Any):
         x, output_size = args
+        return AdaptiveAvgPool2d.__forward(x, output_size)
+
+    @staticmethod
+    def __forward(x: torch.Tensor, output_size: int):
         assert x.is_cuda and x.is_contiguous()

         num_batches, num_channels, num_rows, num_cols = x.shape
@@ -61,7 +66,3 @@ def forward(*args, **kwargs):
         )

         return y
-
-    @staticmethod
-    def setup_context(ctx, inputs, output):
-        pass
9 changes: 3 additions & 6 deletions trident/operation/argmax.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from typing import Any
+
 import torch
 import triton

@@ -20,15 +22,10 @@

 class Argmax(torch.autograd.Function):
     @staticmethod
-    def forward(*args, **kwargs):
+    def forward(ctx: Any, *args: Any, **kwargs: Any):
         input, dim = args
-
         return Argmax.__forward(input, dim)

-    @staticmethod
-    def setup_context(ctx, inputs, output):
-        pass
-
     @staticmethod
     def __forward(input: torch.Tensor, dim: torch.int32):
         factory_kwargs = {"device": input.device, "dtype": torch.int64}
79 changes: 61 additions & 18 deletions trident/operation/attention.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any
+from typing import Any, Tuple

 import torch
 import triton
@@ -22,8 +22,46 @@

 class Attention(torch.autograd.Function):
     @staticmethod
-    def forward(
-        ctx: Any,
+    def forward(ctx: Any, *args: Any, **kwargs: Any):
+        query, key, value, is_causal, softmax_scale, use_accelerator = args
+        output, log_sum_exp, grid = Attention.__forward(query, key, value, is_causal, softmax_scale, use_accelerator)
+
+        ctx.save_for_backward(query, key, value, output, log_sum_exp)
+        ctx.grid = grid
+        ctx.softmax_scale = softmax_scale
+        ctx.embedding_size = key.shape[-1]
+        ctx.is_causal = is_causal
+        ctx.use_accelerator = use_accelerator
+
+        return output
+
+    @staticmethod
+    def backward(ctx: Any, *grad_outputs: Any):
+        (grad_output,) = grad_outputs
+
+        query, key, value, output, log_sum_exp = ctx.saved_tensors
+        grid = ctx.grid
+        softmax_scale = ctx.softmax_scale
+        embedding_size = ctx.embedding_size
+        is_causal = ctx.is_causal
+        use_accelerator = ctx.use_accelerator
+
+        return Attention.__backward(
+            grad_output,
+            query,
+            key,
+            value,
+            output,
+            log_sum_exp,
+            grid,
+            softmax_scale,
+            embedding_size,
+            is_causal,
+            use_accelerator,
+        )
+
+    @staticmethod
+    def __forward(
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
@@ -69,18 +107,23 @@ def forward(
             num_warps=num_warps,
         )

-        ctx.save_for_backward(query, key, value, output, log_sum_exp)
-        ctx.grid = grid
-        ctx.softmax_scale = softmax_scale
-        ctx.embedding_size = key.shape[-1]
-        ctx.is_causal = is_causal
-        ctx.use_accelerator = use_accelerator
-        return output
+        return output, log_sum_exp, grid

     @staticmethod
-    def backward(ctx: Any, grad_output: torch.Tensor):
+    def __backward(
+        grad_output: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        output: torch.Tensor,
+        log_sum_exp: torch.Tensor,
+        grid: Tuple,
+        softmax_scale: float,
+        embedding_size,
+        is_causal: bool,
+        use_accelerator: bool,
+    ):
         block_size = 64
-        query, key, value, output, log_sum_exp = ctx.saved_tensors
         grad_output = grad_output.contiguous()
         grad_query = torch.zeros_like(query)
         grad_key = torch.empty_like(key)
@@ -98,7 +141,7 @@ def backward(ctx: Any, grad_output: torch.Tensor):
             y_stride=x_size,
         )

-        kernel.Attention.backward[(ctx.grid[1],)](
+        kernel.Attention.backward[(grid[1],)](
             grad_query,
             grad_key,
             grad_value,
@@ -116,13 +159,13 @@ def backward(ctx: Any, grad_output: torch.Tensor):
             key.stride(3),
             query.shape[1],
             query.shape[2],
-            ctx.grid[0],
-            ctx.softmax_scale,
+            grid[0],
+            softmax_scale,
             m_block_size=block_size,
             n_block_size=block_size,
-            embedding_size=ctx.embedding_size,
-            is_causal=ctx.is_causal,
-            use_accelerator=ctx.use_accelerator,
+            embedding_size=embedding_size,
+            is_causal=is_causal,
+            use_accelerator=use_accelerator,
             dtype=util.dtype(query.dtype),
             num_warps=8,
         )
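For attention, the commit also folds all ctx bookkeeping into forward and backward and keeps the Triton launches in private __forward/__backward helpers. A hedged usage sketch of the rewritten operation follows; the shapes, dtype, and softmax scale are illustrative assumptions, not values taken from the diff:

import torch
from trident import operation

# Hypothetical tensors; (batch, heads, sequence, embedding) shapes chosen only for illustration.
query = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16, requires_grad=True)
key = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16, requires_grad=True)
value = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16, requires_grad=True)

# Positional arguments in the order the new forward unpacks them:
# query, key, value, is_causal, softmax_scale, use_accelerator.
output = operation.Attention.apply(query, key, value, False, 64 ** -0.5, False)
output.sum().backward()  # exercises the Triton backward kernel via __backward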
15 changes: 8 additions & 7 deletions trident/operation/batch_norm.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from typing import Any
+
 import torch
 import triton

@@ -20,21 +22,20 @@

 class BatchNorm(torch.autograd.Function):
     @staticmethod
-    def forward(*args, **kwargs):
-        return BatchNorm.__forward(*args, **kwargs)
+    def forward(ctx: Any, *args: Any, **kwargs: Any):
+        input, weight, bias, eps, running_mean, running_var = args

-    @staticmethod
-    def setup_context(ctx, inputs, output):
-        inp, wgt, bis, eps, *_ = inputs
-        ctx.save_for_backward(inp, wgt, bis)
+        ctx.save_for_backward(input, weight, bias)
         ctx.eps = eps

+        return BatchNorm.__forward(input, weight, bias, eps, running_mean, running_var)
+
     @staticmethod
     def backward(ctx, *grad_outputs):
         return BatchNorm.__backward(*grad_outputs, *ctx.saved_tensors, ctx.eps)

     @staticmethod
-    def __forward(inp, wgt, bis, eps, running_mean=None, running_var=None):
+    def __forward(inp, wgt, bis, eps, running_mean, running_var):
         bt_sz, vec_sz = inp.shape

         def grid(meta):
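Because the 1.13-style forward now unpacks running_mean and running_var positionally and __forward no longer gives them defaults, the BatchNorm call in trident/module.py above has to pass None for both. A hedged sketch of the direct call, with an illustrative (batch, features) input; the shapes and values are assumptions, not taken from the diff:

import torch
from trident import operation

input = torch.randn(32, 10, device="cuda", requires_grad=True)
weight = torch.ones(10, device="cuda", requires_grad=True)
bias = torch.zeros(10, device="cuda", requires_grad=True)

# Six positional arguments, matching the new forward: input, weight, bias, eps, running_mean, running_var.
output = operation.BatchNorm.apply(input, weight, bias, 1e-05, None, None)
output.sum().backward()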
11 changes: 5 additions & 6 deletions trident/operation/conv2d.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from typing import Any
+
 import torch
 import triton

@@ -20,12 +22,9 @@

 class Conv2d(torch.autograd.Function):
     @staticmethod
-    def forward(*args, **kwargs):
-        return Conv2d.__forward(*args)
-
-    @staticmethod
-    def setup_context(ctx, inputs, output):
-        pass
+    def forward(ctx: Any, *args: Any, **kwargs: Any):
+        input, weight, bias = args
+        return Conv2d.__forward(input, weight, bias)

     @staticmethod
     def __forward(inp, wgt, bis):
12 changes: 6 additions & 6 deletions trident/operation/cosine_similarity.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from typing import Any
+
 import torch
 import triton

@@ -20,17 +22,15 @@

 class CosineSimilarity(torch.autograd.Function):
     @staticmethod
-    def forward(*args, **kwargs):
+    def forward(ctx: Any, *args: Any, **kwargs: Any):
         x1, x2, dim, eps = args
-        return CosineSimilarity.__forward(x1, x2, dim, eps)
+        output, denominator, numerator = CosineSimilarity.__forward(x1, x2, dim, eps)

-    @staticmethod
-    def setup_context(ctx, inputs, output):
-        x1, x2, dim, eps = inputs
-        _, denominator, numerator = output
         ctx.save_for_backward(x1, x2, denominator, numerator)
         ctx.dim = dim

+        return output
+
     @staticmethod
     def backward(ctx, *grad_outputs):
         grad_output = grad_outputs[0]
12 changes: 7 additions & 5 deletions trident/operation/dropout.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from typing import Any
+
 import torch
 import triton

@@ -20,15 +22,15 @@

 class Dropout(torch.autograd.Function):
     @staticmethod
-    def forward(*args, **kwargs):
-        return Dropout.__forward(*args, **kwargs)
+    def forward(ctx: Any, *args: Any, **kwargs: Any):
+        input, p = args
+        output = Dropout.__forward(input, p)

-    @staticmethod
-    def setup_context(ctx, inputs, output):
-        input, p = inputs
         ctx.save_for_backward(input, output)
         ctx.p = p

+        return output
+
     @staticmethod
     def backward(ctx, *grad_outputs):
         (grad_output,) = grad_outputs