Skip to content

Commit

Permalink
Replace pad op implementation in conv kernel with torch implementation (
Browse files Browse the repository at this point in the history
#336)

Signed-off-by: nithinsubbiah <[email protected]>
  • Loading branch information
nithinsubbiah authored Oct 25, 2024
1 parent 1c3ed4b commit 6f3f8c7
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 28 deletions.
24 changes: 2 additions & 22 deletions sharktank/sharktank/ops/qconv_impls.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import warnings

import torch
import torch.nn.functional as F

from sharktank import kernels

Expand Down Expand Up @@ -119,7 +120,7 @@ def qconv2d_tensor_scaled(
padding = _expand_int_to_2_tuple(padding)
dilation = _expand_int_to_2_tuple(dilation)
extended_padding_list = [item for item in padding for _ in range(2)]
padded_input = _pad_last_2d(input_qs, extended_padding_list)
padded_input = F.pad(input_qs, pad=extended_padding_list)
y_qs = _invoke_conv2d_kernel(
padded_input,
weight_qs,
Expand Down Expand Up @@ -258,27 +259,6 @@ def _invoke_pooling_sum_kernel(input, kernel_size, stride, dilation, *, accum_dt
return output


def _pad_last_2d(input_tensor, pad_width):
# pad_width should be in the format [pad_left, pad_right, pad_top, pad_bottom]
pad_left, pad_right, pad_top, pad_bottom = pad_width
batch_size, channels, height, width = input_tensor.shape

# Create a new tensor with the desired padded size filled with zeros
padded_height = height + pad_top + pad_bottom
padded_width = width + pad_left + pad_right
padded_tensor = torch.zeros(
(batch_size, channels, padded_height, padded_width),
dtype=input_tensor.dtype,
device=input_tensor.device,
)

# Copy the values from the input tensor to the appropriate location in the padded tensor
padded_tensor[
:, :, pad_top : pad_top + height, pad_left : pad_left + width
] = input_tensor
return padded_tensor


def _flatten_input_scale_offset_channels(d, m):
"""Flattens either a 4d or 0d scale/offset as [N, C, H, W] to 1D.
Expand Down
7 changes: 4 additions & 3 deletions sharktank/tests/kernels/conv_2d_nchw_fchw_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
from parameterized import parameterized

import torch
import torch.nn.functional as F

from iree.turbine import aot
from sharktank import kernels
from sharktank.ops.qconv_impls import _pad_last_2d


class conv_2d_nchw_fchw_test(unittest.TestCase):
Expand All @@ -36,7 +36,8 @@ def testBS32(self, input_dtype, output_dtype_name, atol, rtol):
inputs = (torch.rand([2, 4, 64, 64]) * 64).to(input_dtype)
padding = [1, 1]
extended_list = [item for item in padding for _ in range(2)]
inputs_pad = _pad_last_2d(inputs, extended_list)
inputs_pad = F.pad(inputs, pad=extended_list)

weights = (torch.rand([8, 4, 3, 3]) * 64).to(input_dtype)
bias = (torch.rand([8]) * 64).to(dtype=output_dtype)
result = kernels.conv_2d_nchw_fchw(
Expand Down Expand Up @@ -68,7 +69,7 @@ def forward(self, a, b, c):
inputs = torch.rand([2, 320, 64, 64]) * 64
padding = [1, 1]
extended_list = [item for item in padding for _ in range(2)]
inputs_pad = _pad_last_2d(inputs, extended_list)
inputs_pad = F.pad(inputs, pad=extended_list)
ep = torch.export.export(
mod,
args=(
Expand Down
6 changes: 3 additions & 3 deletions sharktank/tests/kernels/pooling_nchw_sum_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
from parameterized import parameterized

import torch
import torch.nn.functional as F

from iree.turbine import aot
from sharktank import kernels
from sharktank.ops.qconv_impls import _pad_last_2d


class pooling_nchw_sum_test(unittest.TestCase):
Expand All @@ -34,7 +34,7 @@ def testBS32(self, atol, rtol):
a = (torch.randint(0, 100, (2, 1, 128, 128))).to(torch.float32)
padding = [1, 1]
extended_list = [item for item in padding for _ in range(2)]
inputs_pad = _pad_last_2d(a, extended_list)
inputs_pad = F.pad(a, pad=extended_list)
weight_shape = [3, 3]
stride = [1, 1]
dilations = [1, 1]
Expand Down Expand Up @@ -62,7 +62,7 @@ def forward(self, a):
inputs = torch.rand([2, 1, 128, 128]) * 64
padding = [1, 1]
extended_list = [item for item in padding for _ in range(2)]
inputs_pad = _pad_last_2d(inputs, extended_list)
inputs_pad = F.pad(inputs, pad=extended_list)
ep = torch.export.export(
mod,
args=((inputs_pad).to(dtype),),
Expand Down

0 comments on commit 6f3f8c7

Please sign in to comment.