diff --git a/sharktank/sharktank/ops/qconv_impls.py b/sharktank/sharktank/ops/qconv_impls.py
index 2cb08bb68..add64a605 100644
--- a/sharktank/sharktank/ops/qconv_impls.py
+++ b/sharktank/sharktank/ops/qconv_impls.py
@@ -12,6 +12,7 @@
 import warnings
 
 import torch
+import torch.nn.functional as F
 
 from sharktank import kernels
 
@@ -119,7 +120,7 @@ def qconv2d_tensor_scaled(
     padding = _expand_int_to_2_tuple(padding)
     dilation = _expand_int_to_2_tuple(dilation)
     extended_padding_list = [item for item in padding for _ in range(2)]
-    padded_input = _pad_last_2d(input_qs, extended_padding_list)
+    padded_input = F.pad(input_qs, pad=extended_padding_list)
     y_qs = _invoke_conv2d_kernel(
         padded_input,
         weight_qs,
@@ -258,27 +259,6 @@ def _invoke_pooling_sum_kernel(input, kernel_size, stride, dilation, *, accum_dt
     return output
 
 
-def _pad_last_2d(input_tensor, pad_width):
-    # pad_width should be in the format [pad_left, pad_right, pad_top, pad_bottom]
-    pad_left, pad_right, pad_top, pad_bottom = pad_width
-    batch_size, channels, height, width = input_tensor.shape
-
-    # Create a new tensor with the desired padded size filled with zeros
-    padded_height = height + pad_top + pad_bottom
-    padded_width = width + pad_left + pad_right
-    padded_tensor = torch.zeros(
-        (batch_size, channels, padded_height, padded_width),
-        dtype=input_tensor.dtype,
-        device=input_tensor.device,
-    )
-
-    # Copy the values from the input tensor to the appropriate location in the padded tensor
-    padded_tensor[
-        :, :, pad_top : pad_top + height, pad_left : pad_left + width
-    ] = input_tensor
-    return padded_tensor
-
-
 def _flatten_input_scale_offset_channels(d, m):
     """Flattens either a 4d or 0d scale/offset as [N, C, H, W] to 1D.
 
diff --git a/sharktank/tests/kernels/conv_2d_nchw_fchw_test.py b/sharktank/tests/kernels/conv_2d_nchw_fchw_test.py
index 637bf74c8..ff1430a1a 100644
--- a/sharktank/tests/kernels/conv_2d_nchw_fchw_test.py
+++ b/sharktank/tests/kernels/conv_2d_nchw_fchw_test.py
@@ -12,10 +12,10 @@
 from parameterized import parameterized
 
 import torch
+import torch.nn.functional as F
 
 from iree.turbine import aot
 from sharktank import kernels
-from sharktank.ops.qconv_impls import _pad_last_2d
 
 
 class conv_2d_nchw_fchw_test(unittest.TestCase):
@@ -36,7 +36,8 @@ def testBS32(self, input_dtype, output_dtype_name, atol, rtol):
         inputs = (torch.rand([2, 4, 64, 64]) * 64).to(input_dtype)
         padding = [1, 1]
         extended_list = [item for item in padding for _ in range(2)]
-        inputs_pad = _pad_last_2d(inputs, extended_list)
+        inputs_pad = F.pad(inputs, pad=extended_list)
+
         weights = (torch.rand([8, 4, 3, 3]) * 64).to(input_dtype)
         bias = (torch.rand([8]) * 64).to(dtype=output_dtype)
         result = kernels.conv_2d_nchw_fchw(
@@ -68,7 +69,7 @@ def forward(self, a, b, c):
         inputs = torch.rand([2, 320, 64, 64]) * 64
         padding = [1, 1]
         extended_list = [item for item in padding for _ in range(2)]
-        inputs_pad = _pad_last_2d(inputs, extended_list)
+        inputs_pad = F.pad(inputs, pad=extended_list)
         ep = torch.export.export(
             mod,
             args=(
diff --git a/sharktank/tests/kernels/pooling_nchw_sum_test.py b/sharktank/tests/kernels/pooling_nchw_sum_test.py
index 205391d96..e512eb484 100644
--- a/sharktank/tests/kernels/pooling_nchw_sum_test.py
+++ b/sharktank/tests/kernels/pooling_nchw_sum_test.py
@@ -12,10 +12,10 @@
 from parameterized import parameterized
 
 import torch
+import torch.nn.functional as F
 
 from iree.turbine import aot
 from sharktank import kernels
-from sharktank.ops.qconv_impls import _pad_last_2d
 
 
 class pooling_nchw_sum_test(unittest.TestCase):
@@ -34,7 +34,7 @@ def testBS32(self, atol, rtol):
         a = (torch.randint(0, 100, (2, 1, 128, 128))).to(torch.float32)
         padding = [1, 1]
         extended_list = [item for item in padding for _ in range(2)]
-        inputs_pad = _pad_last_2d(a, extended_list)
+        inputs_pad = F.pad(a, pad=extended_list)
         weight_shape = [3, 3]
         stride = [1, 1]
         dilations = [1, 1]
@@ -62,7 +62,7 @@ def forward(self, a):
         inputs = torch.rand([2, 1, 128, 128]) * 64
         padding = [1, 1]
         extended_list = [item for item in padding for _ in range(2)]
-        inputs_pad = _pad_last_2d(inputs, extended_list)
+        inputs_pad = F.pad(inputs, pad=extended_list)
         ep = torch.export.export(
             mod,
             args=((inputs_pad).to(dtype),),
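
A note on the substitution, since it is repeated across all three files: below is
a minimal sanity check (a sketch, not part of the patch; the tensor shape and
pad values are illustrative) showing that torch.nn.functional.pad is a drop-in
replacement for the deleted _pad_last_2d helper. F.pad reads a 4-element pad
list as [left, right, top, bottom], applied from the last dimension inward, and
defaults to constant zero fill, which is exactly what the hand-rolled helper
implemented.

    import torch
    import torch.nn.functional as F

    # Illustrative NCHW tensor and the [left, right, top, bottom] list that
    # extended_padding_list / extended_list build in the patched code.
    x = torch.arange(2 * 3 * 4 * 5, dtype=torch.float32).reshape(2, 3, 4, 5)
    pad = [1, 1, 2, 2]  # width padded by (1, 1), height by (2, 2)

    padded = F.pad(x, pad=pad)  # mode="constant", value=0 by default
    assert padded.shape == (2, 3, 4 + 2 + 2, 5 + 1 + 1)
    assert torch.equal(padded[:, :, 2:6, 1:6], x)  # interior is the original
    assert torch.all(padded[:, :, :2, :] == 0)     # new border is zero-filled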