From c5fd1b42409bc963dbe8e59ca5a8fe6261316b9c Mon Sep 17 00:00:00 2001 From: beinggod Date: Fri, 30 Aug 2024 03:32:03 +0000 Subject: [PATCH 1/2] [PyTorch] Add FP8 padding and unpaading module 1. Add multi-tensor padding kernel 2. Add FP8Padding and Fp8Unpadding module 3. Add padding grouped linear UT case Signed-off-by: beinggod --- tests/cpp/operator/CMakeLists.txt | 1 + tests/cpp/operator/test_multi_padding.cu | 169 +++++++++++++ tests/pytorch/test_numerics.py | 189 ++++++++++++++ transformer_engine/common/CMakeLists.txt | 1 + .../include/transformer_engine/padding.h | 51 ++++ transformer_engine/common/util/padding.cu | 233 ++++++++++++++++++ transformer_engine/pytorch/__init__.py | 1 + .../pytorch/cpp_extensions/__init__.py | 1 + .../pytorch/cpp_extensions/padding.py | 29 +++ transformer_engine/pytorch/csrc/common.h | 1 + transformer_engine/pytorch/csrc/extensions.h | 8 + .../pytorch/csrc/extensions/padding.cu | 79 ++++++ .../pytorch/csrc/extensions/pybind.cpp | 3 +- transformer_engine/pytorch/module/__init__.py | 2 + .../pytorch/module/fp8_padding.py | 123 +++++++++ .../pytorch/module/fp8_unpadding.py | 119 +++++++++ 16 files changed, 1009 insertions(+), 1 deletion(-) create mode 100644 tests/cpp/operator/test_multi_padding.cu create mode 100644 transformer_engine/common/include/transformer_engine/padding.h create mode 100644 transformer_engine/common/util/padding.cu create mode 100644 transformer_engine/pytorch/cpp_extensions/padding.py create mode 100644 transformer_engine/pytorch/csrc/extensions/padding.cu create mode 100644 transformer_engine/pytorch/module/fp8_padding.py create mode 100644 transformer_engine/pytorch/module/fp8_unpadding.py diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index e590d8e92a..45806e7022 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -13,6 +13,7 @@ add_executable(test_operator test_layernorm.cu test_rmsnorm.cu test_multi_cast_transpose.cu + test_multi_padding.cu test_causal_softmax.cu ../test_common.cu) diff --git a/tests/cpp/operator/test_multi_padding.cu b/tests/cpp/operator/test_multi_padding.cu new file mode 100644 index 0000000000..e9e42725fe --- /dev/null +++ b/tests/cpp/operator/test_multi_padding.cu @@ -0,0 +1,169 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include "../test_common.h" + +using namespace transformer_engine; + +namespace { + +template +void compute_ref(const std::vector>& input_list, + std::vector>& output_list, + const std::vector& height_list, + const std::vector& width_list, + const std::vector& padded_height_list) { + using compute_t = float; + for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) { + const auto& input = input_list[tensor_id]; + auto& output = output_list[tensor_id]; + const size_t height = height_list[tensor_id]; + const size_t width = width_list[tensor_id]; + const size_t padded_height = padded_height_list[tensor_id]; + + for (size_t i = 0; i < padded_height; ++i) { + if (i < height) { + for (size_t j = 0; j < width; ++j) { + const compute_t x = static_cast(input[i * width + j]); + const OutputType y = static_cast(x); + output[i * width + j] = y; + } + } else { + for (size_t j = 0; j < width; ++j) { + output[i * width + j] = static_cast(0.f); + } + } + } + } +} + +template +void performTest() { + using namespace test; + + const DType itype = TypeInfo::dtype; + const DType otype = TypeInfo::dtype; + const std::vector> tensor_dims = {{1,1}, + {1,768}, + {768,1}, + {768,768}, + {43,43}, + {43,256}, + {256,43}, + {256,256}}; + const size_t num_tensors = tensor_dims.size(); + constexpr int align = 16; + + // Buffers for Transformer Engine implementation + std::vector input_list, output_list, output_t_list; + + // Buffers for reference implementation + std::vector> ref_input_list; + std::vector> ref_output_list; + std::vector ref_height_list(num_tensors), ref_width_list(num_tensors); + std::vector ref_padded_height_list(num_tensors); + + // Initialize buffers + for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) { + const size_t height = tensor_dims[tensor_id].first; + const size_t width = tensor_dims[tensor_id].second; + const size_t padded_height = (height + align - 1) / align * align; + input_list.emplace_back(Tensor({ height, width }, itype)); + output_list.emplace_back(Tensor({ padded_height, width }, otype)); + + auto& input = input_list.back(); + auto& output = output_list.back(); + fillUniform(&input); + setRandomScale(&output); + + ref_input_list.emplace_back(height*width); + ref_output_list.emplace_back(padded_height*width); + + std::copy(input.cpu_dptr(), + input.cpu_dptr() + height * width, + ref_input_list.back().begin()); + ref_height_list[tensor_id] = height; + ref_width_list[tensor_id] = width; + ref_padded_height_list[tensor_id] = padded_height; + } + + // Transformer Engine implementation + auto make_nvte_vector = [](std::vector& tensor_list) + -> std::vector { + std::vector nvte_tensor_list; + for (auto& tensor : tensor_list) { + nvte_tensor_list.emplace_back(tensor.data()); + } + return nvte_tensor_list; + }; + nvte_multi_padding(num_tensors, + make_nvte_vector(input_list).data(), + make_nvte_vector(output_list).data(), + ref_padded_height_list.data(), + 0); + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + // Reference implementation + compute_ref(ref_input_list, + ref_output_list, + ref_height_list, + ref_width_list, + ref_padded_height_list); + + // Check correctness + for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) { + auto [atol, rtol] = getTolerances(otype); + 
compareResults("output", + output_list[tensor_id], + ref_output_list[tensor_id].data(), + atol, rtol); + } +} + +} // namespace + +class MultiPaddingTestSuite + : public ::testing::TestWithParam< + transformer_engine::DType> {}; + +TEST_P(MultiPaddingTestSuite, TestMultiPaddingTranspose) { + using namespace transformer_engine; + using namespace test; + + const DType input_type = GetParam(); + const DType output_type = input_type; + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, + performTest(); + ); + ); +} + + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + MultiPaddingTestSuite, + ::testing::ValuesIn(test::all_fp_types), + [](const testing::TestParamInfo& info) { + std::string name = test::typeName(info.param); + return name; + }); diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 85cd4fc256..cb1ab54ca1 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -7,6 +7,7 @@ from typing import Dict, List, Optional import pytest import copy +import random import torch import torch.nn as nn @@ -30,6 +31,8 @@ TransformerLayer, LayerNorm, InferenceParams, + Fp8Padding, + Fp8Unpadding, ) from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint from transformer_engine.pytorch.cpp_extensions import fp8_gemm, fp8_grouped_gemm, gemm, grouped_gemm @@ -354,6 +357,40 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: return (input > 0) * input * input +class TorcGroupedLinearWithPadding(nn.Module): + + def __init__( + self, num_gemms, in_features, out_features, bias, params_dtype, parallel_mode, fp8 + ) -> None: + super().__init__() + + self.padding = Fp8Padding(num_gemms) + self.linear_fn = GroupedLinear( + num_gemms, + in_features, + out_features, + bias=bias, + params_dtype=params_dtype, + parallel_mode=parallel_mode, + device="cuda", + ) + self.unpadding = Fp8Unpadding(num_gemms) + + self.fp8 = fp8 + + def forward(self, inp: torch.Tensor, m_splits: List[int]) -> torch.Tensor: + if self.fp8: + orig_m_splits = m_splits + inp, m_splits = self.padding(inp, m_splits) + + out = self.linear_fn(inp, m_splits) + + if self.fp8: + out = self.unpadding(out, orig_m_splits) + + return out + + _supported_act = { "geglu": nn.GELU(approximate="tanh"), "gelu": nn.GELU(approximate="tanh"), @@ -1328,6 +1365,158 @@ def test_grouped_linear_accuracy_parallel_mode(parallel_mode): ) +def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False): + + def _pad_tensor_for_fp8(hidden_states, tokens_per_expert): + """Padding tensor shapes to multiples of 16.""" + padded_tokens_per_expert = [ + (num_tokens + 15) // 16 * 16 for num_tokens in tokens_per_expert + ] + hidden_states = torch.split(hidden_states, tokens_per_expert) + padded_hidden_states = [] + for hidden_state, actual_num_tokens, padded_num_tokens in zip( + hidden_states, tokens_per_expert, padded_tokens_per_expert + ): + padded_hidden_states.append(hidden_state) + if padded_num_tokens > actual_num_tokens: + pad_tensor = torch.zeros( + padded_num_tokens - actual_num_tokens, + hidden_state.shape[1], + dtype=hidden_state.dtype, + device=hidden_state.device, + ) + padded_hidden_states.append(pad_tensor) + padded_hidden_states = torch.cat(padded_hidden_states, dim=0) + return padded_hidden_states, padded_tokens_per_expert + + def _unpad_tensor_for_fp8(padded_hidden_states, actual_tokens_per_expert, tokens_per_expert): + inputmats = torch.split( + padded_hidden_states.view(-1, 
padded_hidden_states.shape[-1]), tokens_per_expert + ) + hidden_states = torch.cat( + [ + grad_output_mat[: actual_tokens_per_expert[i]] + for i, grad_output_mat in enumerate(inputmats) + ], + dim=0, + ) + + return hidden_states + + def _generate_random_numbers(n, total_sum): + if n <= 0: + return [] + + # reset seed + random.seed(seed) + + breaks = sorted(random.sample(range(1, total_sum), n - 1)) + random_numbers = ( + [breaks[0]] + + [breaks[i] - breaks[i - 1] for i in range(1, n - 1)] + + [total_sum - breaks[-1]] + ) + + return random_numbers + + reset_rng_states() + if fp8: + FP8GlobalStateManager.reset() + + inp_hidden_states = torch.randn( + (config.seq_len * bs, config.hidden_size), + dtype=dtype, + device="cuda", + requires_grad=True, + ) + inp_hidden_states.retain_grad() + + m_splits = _generate_random_numbers(num_gemms, config.seq_len * bs) + + with fp8_autocast(enabled=fp8): + if isinstance(block, TorcGroupedLinearWithPadding): + out = block(inp_hidden_states, m_splits) + else: + if fp8: + padded_inp_hidden_states, padding_m_splits = _pad_tensor_for_fp8( + inp_hidden_states, m_splits + ) + padded_inp_hidden_states = block(padded_inp_hidden_states, padding_m_splits) + out = _unpad_tensor_for_fp8(padded_inp_hidden_states, m_splits, padding_m_splits) + else: + out = block(inp_hidden_states, m_splits) + + loss = out.sum() + loss.backward() + + torch.cuda.synchronize() + outputs = [out, inp_hidden_states.grad] + for p in block.parameters(): + if p.requires_grad: + outputs.append(p.grad) + return outputs + + +@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("num_gemms", [3, 6]) +@pytest.mark.parametrize("bs", batch_sizes) +@pytest.mark.parametrize("model", model_configs.keys()) +@pytest.mark.parametrize("fp8", all_boolean) +@pytest.mark.parametrize("fp8_model_params", all_boolean) +def test_padding_grouped_linear_accuracy( + dtype, num_gemms, bs, model, fp8, fp8_model_params, parallel_mode=None +): + if fp8 and not fp8_available: + pytest.skip(reason_for_no_fp8) + + config = model_configs[model] + if config.seq_len % 16 != 0 and fp8: + pytest.skip("FP8 requires sequence length to be divisible by 16.") + + with fp8_model_init(enabled=fp8 and fp8_model_params): + grouped_linear = TorcGroupedLinearWithPadding( + num_gemms, + config.hidden_size, + 4 * config.hidden_size, + bias=False, + params_dtype=dtype, + parallel_mode=parallel_mode, + fp8=fp8, + ).eval() + + with fp8_model_init(enabled=fp8 and fp8_model_params): + ref_grouped_linear = GroupedLinear( + num_gemms, + config.hidden_size, + 4 * config.hidden_size, + bias=False, + params_dtype=dtype, + parallel_mode=parallel_mode, + device="cuda", + ).eval() + + # Share params + with torch.no_grad(): + inner_grouped_linear = grouped_linear.linear_fn + for i in range(num_gemms): + setattr( + ref_grouped_linear, + f"weight{i}", + Parameter(getattr(inner_grouped_linear, f"weight{i}").clone()), + ) + + outputs = _test_padding_grouped_linear_accuracy( + grouped_linear, num_gemms, bs, dtype, config, fp8 + ) + outputs_ref = _test_padding_grouped_linear_accuracy( + ref_grouped_linear, num_gemms, bs, dtype, config, fp8 + ) + + # Shoule be bit-wise match + for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)): + torch.testing.assert_close(o, o_ref, rtol=0, atol=0) + + def _test_gpt_e2e_cuda_graph(block, bs, dtype, config, graph): reset_rng_states() diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index a6fd6815c3..647d2c474d 100644 --- 
a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -71,6 +71,7 @@ list(APPEND transformer_engine_SOURCES rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu rmsnorm/rmsnorm_fwd_cuda_kernel.cu util/cast.cu + util/padding.cu util/cuda_driver.cpp util/cuda_runtime.cpp util/rtc.cpp diff --git a/transformer_engine/common/include/transformer_engine/padding.h b/transformer_engine/common/include/transformer_engine/padding.h new file mode 100644 index 0000000000..a419b38234 --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/padding.h @@ -0,0 +1,51 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file padding.h + * \brief Functions handling padding. + */ + +#ifndef TRANSFORMER_ENGINE_PADDING_H_ +#define TRANSFORMER_ENGINE_PADDING_H_ + +#include "transformer_engine.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/*! \brief Padding multiple tensors. + * + * NOTE: Padding mode only support bottom. + * + * For example, 3x3 matrix pad to 4x3 matrix. + * + * source + * | 1 | 2 | 3 | + * | 4 | 5 | 6 | + * | 7 | 8 | 9 | + * + * destination + * | 1 | 2 | 3 | + * | 4 | 5 | 6 | + * | 7 | 8 | 9 | + * | 0 | 0 | 0 | + * + * \param[in] num_tensors Number of tensors. + * \param[in] input_list List of 2D input tensors. + * \param[in,out] output_list List of padded tensors. Dimensions + * match tensors in input_list. + * \param[in] padded_num_rows_list List of padded num rows corresponding to input tensors. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_multi_padding(size_t num_tensors, const NVTETensor* input_list, NVTETensor* output_list, + const int* padded_num_rows_list, cudaStream_t stream); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TRANSFORMER_ENGINE_PADDING_H_ diff --git a/transformer_engine/common/util/padding.cu b/transformer_engine/common/util/padding.cu new file mode 100644 index 0000000000..d7ed12f8e4 --- /dev/null +++ b/transformer_engine/common/util/padding.cu @@ -0,0 +1,233 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include +#include + +#include +#include +#include + +#include "../common.h" +#include "../utils.cuh" + +namespace transformer_engine { + +namespace { + +// Parameters to tune +constexpr int n_warps_per_tile = 4; +constexpr int threads_per_block = THREADS_PER_WARP * n_warps_per_tile; +constexpr int desired_load_size = 8; +constexpr int desired_store_size = 8; +constexpr int kMaxTensorsPerKernel = 64; // Args must be <4 KB + +struct MultiPaddingArgs { + // (input) Data buffers for input tensors + void* input_list[kMaxTensorsPerKernel]; + // (output) Data buffers for cast output tensors + void* output_list[kMaxTensorsPerKernel]; + // Input matrix heights + int num_rows_list[kMaxTensorsPerKernel]; + // Input matrix heights (padded) + int padded_num_rows_list[kMaxTensorsPerKernel]; + // Input matrix widths + int row_length_list[kMaxTensorsPerKernel]; + // tensor + int block_range[kMaxTensorsPerKernel + 1]; + // Number of tensors being processed by kernel + int num_tensors; +}; + +template +__global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiPaddingArgs args) { + using IVec = Vec; + using OVecC = Vec; + using OVecT = Vec; + + // Thread indices + // Note: Block is interpreted as a warp_size x num_warps grid + constexpr int bdimx = THREADS_PER_WARP; + constexpr int bdimy = n_warps_per_tile; + const int tid = threadIdx.x; + const int tidx = tid % bdimx; + const int tidy = tid / bdimx; + const int bid = blockIdx.x; + + // Input tensors are divided into tiles + // Note: Each tile is a warp_size x warp_size grid of nvec_out x nvec_in subtiles + constexpr int tile_dim_m = THREADS_PER_WARP * nvec_out; + constexpr int tile_dim_n = THREADS_PER_WARP * nvec_in; + + // Number of nvec_out x nvec_in subtiles for each thread to + // load/store + constexpr int n_iterations = THREADS_PER_WARP / n_warps_per_tile; + + // Find tensor corresponding to block + int tensor_id = 0; + while (args.block_range[tensor_id + 1] <= bid) { + ++tensor_id; + } + const IType* input = reinterpret_cast(args.input_list[tensor_id]); + OType* output = reinterpret_cast(args.output_list[tensor_id]); + const int num_rows = args.num_rows_list[tensor_id]; + const int padded_num_rows = args.padded_num_rows_list[tensor_id]; + const int row_length = args.row_length_list[tensor_id]; + + // Find position of tile within tensor + const int num_tiles_n = (row_length + tile_dim_n - 1) / tile_dim_n; + const int tile_id = bid - args.block_range[tensor_id]; + const int tile_id_m = tile_id / num_tiles_n; + const int tile_id_n = tile_id % num_tiles_n; + const int tile_row = tile_id_m * tile_dim_m; + const int tile_col = tile_id_n * tile_dim_n; + + // Load input and store to registers + // Note: Each thread loads n_iterations subtiles, casts to output + // type, and transposes in registers. 
+ OType local_zero = static_cast(0.f); +#pragma unroll + for (int iter = 0; iter < n_iterations; ++iter) { + const int i1 = tidy + iter * bdimy; + const int j1 = tidx; +#pragma unroll + for (int i2 = 0; i2 < nvec_out; ++i2) { + const int row = tile_row + i1 * nvec_out + i2; + const int col = tile_col + j1 * nvec_in; + IVec local_input; + OVecC local_output; + local_input.clear(); + if (row < num_rows) { + for (int j2 = 0; j2 < nvec_in; ++j2) { + if (col + j2 < row_length) { + local_input.data.elt[j2] = input[row * row_length + col + j2]; + } + } + } +#pragma unroll + for (int j2 = 0; j2 < nvec_in; ++j2) { + const CType x = CType(local_input.data.elt[j2]); + const OType y = OType(x); + local_output.data.elt[j2] = y; + } + if (row < num_rows) { + for (int j2 = 0; j2 < nvec_in; ++j2) { + if (col + j2 < row_length) { + output[row * row_length + col + j2] = local_output.data.elt[j2]; + } + } + } else if (row < padded_num_rows) { + // padding + for (int j2 = 0; j2 < nvec_in; ++j2) { + if (col + j2 < row_length) { + output[row * row_length + col + j2] = local_zero; + } + } + } + } + } +} + +} // namespace + +void multi_padding(const std::vector input_list, std::vector output_list, + const std::vector padded_num_rows_list, cudaStream_t stream) { + // Check that number of tensors is valid + NVTE_CHECK(output_list.size() == input_list.size(), + "Number of input and output tensors must match"); + if (input_list.empty()) { + return; + } + + // Check that tensor properties are valid + DType itype = input_list[0]->data.dtype; + DType otype = output_list[0]->data.dtype; + for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) { + const auto& input = *input_list[tensor_id]; + const auto& output = *output_list[tensor_id]; + CheckInputTensor(input, "multi_padding_input_" + std::to_string(tensor_id)); + CheckInputTensor(output, "multi_padding_output_" + std::to_string(tensor_id)); + + NVTE_CHECK(input.data.dtype == itype, "Input tensor types do not match."); + NVTE_CHECK(output.data.dtype == otype, "Output tensor types do not match."); + + NVTE_CHECK(input.data.shape.size() == 2, "Input tensor must have 2 dimensions."); + NVTE_CHECK(output.data.shape[0] == padded_num_rows_list[tensor_id], + "output tensor shape does not match padded input shape."); + } + + // Input matrices are divided into tiles + // Note: Each tile is a warp_size x warp_size grid of nvec_out x nvec_in subtiles + const int tile_dim_m = THREADS_PER_WARP * desired_store_size / typeToSize(otype); + const int tile_dim_n = THREADS_PER_WARP * desired_load_size / typeToSize(itype); + + // Add tensors to kernel argument struct + MultiPaddingArgs kernel_args; + kernel_args.num_tensors = 0; + kernel_args.block_range[0] = 0; + for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) { + // Launch kernel if argument struct is full + if (kernel_args.num_tensors == kMaxTensorsPerKernel) { + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + itype, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + otype, OutputType, constexpr int nvec_in = desired_load_size / sizeof(InputType); + constexpr int nvec_out = desired_store_size / sizeof(OutputType); + const int n_blocks = kernel_args.block_range[kernel_args.num_tensors]; + multi_padding_kernel + <<>>(kernel_args);); // NOLINT(*) + ); // NOLINT(*) + kernel_args.num_tensors = 0; + } + + // Calculate number of thread blocks needed for tensor + const int num_rows = input_list[tensor_id]->data.shape[0]; + const int padded_num_rows = padded_num_rows_list[tensor_id]; + const int row_length = 
input_list[tensor_id]->data.shape[1]; + const int num_tiles_m = (padded_num_rows + tile_dim_m - 1) / tile_dim_m; + const int num_tiles_n = (row_length + tile_dim_n - 1) / tile_dim_n; + const int num_tiles = num_tiles_m * num_tiles_n; + + // Add tensor to kernel argument struct + const int pos = kernel_args.num_tensors; + kernel_args.input_list[pos] = const_cast(input_list[tensor_id]->data.dptr); + kernel_args.output_list[pos] = output_list[tensor_id]->data.dptr; + kernel_args.num_rows_list[pos] = num_rows; + kernel_args.padded_num_rows_list[pos] = padded_num_rows; + kernel_args.row_length_list[pos] = row_length; + kernel_args.block_range[pos + 1] = kernel_args.block_range[pos] + num_tiles; + kernel_args.num_tensors++; + } + + // Launch kernel + if (kernel_args.num_tensors > 0) { + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + itype, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + otype, OutputType, constexpr int nvec_in = desired_load_size / sizeof(InputType); + constexpr int nvec_out = desired_store_size / sizeof(OutputType); + const int n_blocks = kernel_args.block_range[kernel_args.num_tensors]; + multi_padding_kernel + <<>>(kernel_args);); // NOLINT(*) + ); // NOLINT(*) + } +} + +} // namespace transformer_engine + +void nvte_multi_padding(size_t num_tensors, const NVTETensor* input_list, NVTETensor* output_list, + const int* padded_num_rows_list, cudaStream_t stream) { + NVTE_API_CALL(nvte_multi_padding); + using namespace transformer_engine; + std::vector input_list_, output_list_; + std::vector padded_num_rows_list_; + for (size_t i = 0; i < num_tensors; ++i) { + input_list_.push_back(reinterpret_cast(const_cast(input_list[i]))); + output_list_.push_back(reinterpret_cast(output_list[i])); + padded_num_rows_list_.push_back(padded_num_rows_list[i]); + } + multi_padding(input_list_, output_list_, padded_num_rows_list_, stream); +} diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 1c755491b0..63151b1055 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -38,6 +38,7 @@ def _load_library(): from transformer_engine.pytorch.module import LayerNorm from transformer_engine.pytorch.module import RMSNorm from transformer_engine.pytorch.module import GroupedLinear +from transformer_engine.pytorch.module import Fp8Padding, Fp8Unpadding from transformer_engine.pytorch.module import initialize_ub from transformer_engine.pytorch.module import destroy_ub from transformer_engine.pytorch.attention import DotProductAttention diff --git a/transformer_engine/pytorch/cpp_extensions/__init__.py b/transformer_engine/pytorch/cpp_extensions/__init__.py index 61d688f3f4..9f3c1b2424 100644 --- a/transformer_engine/pytorch/cpp_extensions/__init__.py +++ b/transformer_engine/pytorch/cpp_extensions/__init__.py @@ -11,3 +11,4 @@ from .activation import * from .normalization import * from .cast import * +from .padding import * diff --git a/transformer_engine/pytorch/cpp_extensions/padding.py b/transformer_engine/pytorch/cpp_extensions/padding.py new file mode 100644 index 0000000000..41dfbe2466 --- /dev/null +++ b/transformer_engine/pytorch/cpp_extensions/padding.py @@ -0,0 +1,29 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +"""Python interface for transpose extensions""" +from typing import List, Tuple, Union +import torch +import transformer_engine_torch as tex + + +__all__ = [ + "multi_padding_fused", +] + + +def multi_padding_fused( + inp: torch.Tensor, + row_list: List[int], + padded_row_list: List[int], + out: torch.Tensor, +) -> Union[Tuple[List[torch.Tensor], List[torch.Tensor]], None]: + """Padding""" + + tex.fused_multi_row_padding( + inp, + out, + row_list, + padded_row_list, + ) diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 7fb9953f94..04a1193a71 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -28,6 +28,7 @@ #include #include #include +#include #include #include #include diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 1a6f5f157e..0ccd9169d7 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -475,4 +475,12 @@ void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag, float momentum, float dampening, float lr, bool nesterov, bool first_run, bool wd_after_momentum, float scale); +/*************************************************************************************************** + * padding + **************************************************************************************************/ + +void fused_multi_row_padding(at::Tensor input, at::Tensor output, + std::vector input_row_list, + std::vector padded_input_row_list); + #endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_ diff --git a/transformer_engine/pytorch/csrc/extensions/padding.cu b/transformer_engine/pytorch/csrc/extensions/padding.cu new file mode 100644 index 0000000000..d975ebeeef --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/padding.cu @@ -0,0 +1,79 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "extensions.h" + +void fused_multi_row_padding(at::Tensor input, at::Tensor output, + std::vector input_row_list, + std::vector padded_input_row_list) { + using namespace transformer_engine; + + NVTE_CHECK(input_row_list.size() == padded_input_row_list.size(), + "Number of input row list and padded row list must match."); + NVTE_CHECK(input.dim() == 2, "Dimension of input must equal 2."); + NVTE_CHECK(output.dim() == 2, "Dimension of output must equal 2."); + + const int num_tensors = input_row_list.size(); + // Extract properties from PyTorch tensors + std::vector input_dptr_list, output_dptr_list; + std::vector> input_shape_list, output_shape_list; + std::vector input_type_list; + void* d_input_ptr = reinterpret_cast(input.data_ptr()); + void* d_output_ptr = reinterpret_cast(output.data_ptr()); + for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) { + input_dptr_list.push_back(d_input_ptr); + output_dptr_list.push_back(d_output_ptr); + + // Move the input pointer to the next split. 
+ char* input_char_ptr = reinterpret_cast(d_input_ptr); + const size_t input_dptr_offset = + input_row_list[tensor_id] * input.size(1) * input.element_size(); + input_char_ptr += input_dptr_offset; + d_input_ptr = reinterpret_cast(input_char_ptr); + + input_shape_list.push_back({input_row_list[tensor_id], static_cast(input.size(1))}); + input_type_list.push_back(GetTransformerEngineDType(input.scalar_type())); + + // Move the output pointer to the next split. + char* output_char_ptr = reinterpret_cast(d_output_ptr); + const size_t output_dptr_offset = + padded_input_row_list[tensor_id] * output.size(1) * output.element_size(); + output_char_ptr += output_dptr_offset; + d_output_ptr = reinterpret_cast(output_char_ptr); + + output_shape_list.push_back( + {padded_input_row_list[tensor_id], static_cast(output.size(1))}); + } + + // Construct TE tensors + std::vector nvte_input_list, nvte_output_list; + std::vector tensor_wrappers; + auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector& shape, + transformer_engine::DType dtype) -> NVTETensor { + tensor_wrappers.emplace_back(makeTransformerEngineTensor(dptr, shape, dtype)); + return tensor_wrappers.back().data(); + }; + + std::vector padded_num_rows_list; + for (size_t i = 0; i < input_dptr_list.size(); ++i) { + if (input_dptr_list[i] == nullptr || input_row_list[i] == 0) continue; + nvte_input_list.emplace_back( + make_tensor(input_dptr_list[i], input_shape_list[i], input_type_list[i])); + nvte_output_list.emplace_back( + make_tensor(output_dptr_list[i], output_shape_list[i], input_type_list[i])); + padded_num_rows_list.emplace_back(padded_input_row_list[i]); + } + + // Check tensor lists + NVTE_CHECK(nvte_output_list.size() == nvte_input_list.size(), + "Number of input and output tensors must match"); + NVTE_CHECK(padded_num_rows_list.size() == nvte_input_list.size() && + "Number of input and padded row list must match"); + + // Launch TE kernel + nvte_multi_padding(nvte_input_list.size(), nvte_input_list.data(), nvte_output_list.data(), + padded_num_rows_list.data(), at::cuda::getCurrentCUDAStream()); +} diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index f903a1c35b..d95af17e6f 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -146,7 +146,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("fused_amax_and_scale_update_after_reduction", &fused_amax_and_scale_update_after_reduction, "Update amax history and FP8 scale/scale_inv after reduction", py::call_guard()); - + m.def("fused_multi_row_padding", &fused_multi_row_padding, "Fused Multi-tensor padding", + py::call_guard()); // fused apply rope m.def("fused_rope_forward", &fused_rope_forward, "Fused Apply RoPE FWD", py::call_guard()); diff --git a/transformer_engine/pytorch/module/__init__.py b/transformer_engine/pytorch/module/__init__.py index 6994f586b1..ba4755efe3 100644 --- a/transformer_engine/pytorch/module/__init__.py +++ b/transformer_engine/pytorch/module/__init__.py @@ -9,4 +9,6 @@ from .layernorm_mlp import LayerNormMLP from .layernorm import LayerNorm from .rmsnorm import RMSNorm +from .fp8_padding import Fp8Padding +from .fp8_unpadding import Fp8Unpadding from .base import initialize_ub, destroy_ub diff --git a/transformer_engine/pytorch/module/fp8_padding.py b/transformer_engine/pytorch/module/fp8_padding.py new file mode 100644 index 0000000000..60bac91353 --- /dev/null +++ 
b/transformer_engine/pytorch/module/fp8_padding.py @@ -0,0 +1,123 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""FP8 Padding API""" + +from typing import Union, List + +import torch + +from ..cpp_extensions import ( + multi_padding_fused, +) +from ..jit import no_torch_dynamo + + +__all__ = ["Fp8Padding"] + + +class _Fp8Padding(torch.autograd.Function): + """functional FP8 padding""" + + @staticmethod + def forward( + ctx, + inp: torch.Tensor, + m_splits: List[int], + padded_m_splits: List[int], + is_grad_enabled: bool, + ) -> torch.Tensor: + # Make sure input dimensions are compatible + in_features = inp.shape[-1] + + # Allocate cast and transpose output tensor + total_row = sum(padded_m_splits) + out = torch.empty([total_row, in_features], dtype=inp.dtype, device=inp.device) + + multi_padding_fused(inp.view(-1, in_features), m_splits, padded_m_splits, out) + + if is_grad_enabled: + ctx.m_splits = m_splits + ctx.padded_m_splits = padded_m_splits + ctx.requires_dgrad = inp.requires_grad + + return out + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + + grad_input = None + if ctx.requires_dgrad: + grad_output = grad_output.contiguous() + + grad_output_mats = torch.split( + grad_output.view(-1, grad_output.shape[-1]), ctx.padded_m_splits + ) + grad_input = torch.cat( + [ + grad_output_mat[: ctx.m_splits[i]] + for i, grad_output_mat in enumerate(grad_output_mats) + ], + dim=0, + ) + + return (grad_input, None, None, None) + + +class Fp8Padding(torch.nn.Module): + """ + Apply the padding for Grouped GEMM input. + + Parameters + ---------- + num_gemms: int + number of GEMMs to be performed simutaneously. + """ + + def __init__( + self, + num_gemms, + ) -> None: + super().__init__() + + self.num_gemms = num_gemms + + @no_torch_dynamo() + def forward( + self, + inp: torch.Tensor, + m_splits: List[int], + ) -> Union[torch.Tensor, List[int]]: + """ + Apply the padding to the input. + + Parameters + ---------- + inp : torch.Tensor + Input tensor. + m_splits : List[int] + List of integers representing the split of the input tensor. + """ + + assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs." + + # FP8 padding calculate + padded_m_splits = [(m + 15) // 16 * 16 for m in m_splits] + + if torch.is_grad_enabled(): + fn = _Fp8Padding.apply + args = [] + else: + fn = _Fp8Padding.forward + args = [None] + + args += ( + inp, + m_splits, + padded_m_splits, + torch.is_grad_enabled(), + ) + out = fn(*args) + + return out, padded_m_splits diff --git a/transformer_engine/pytorch/module/fp8_unpadding.py b/transformer_engine/pytorch/module/fp8_unpadding.py new file mode 100644 index 0000000000..6e08f849ef --- /dev/null +++ b/transformer_engine/pytorch/module/fp8_unpadding.py @@ -0,0 +1,119 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +"""FP8 Padding API""" + +from typing import List + +import torch + +from ..cpp_extensions import ( + multi_padding_fused, +) +from ..jit import no_torch_dynamo + + +__all__ = ["Fp8Unpadding"] + + +class _Fp8Unpadding(torch.autograd.Function): + """functional FP8 unpadding""" + + @staticmethod + def forward( + ctx, + inp: torch.Tensor, + m_splits: List[int], + padded_m_splits: List[int], + is_grad_enabled: bool, + ) -> torch.Tensor: + inputmats = torch.split(inp.view(-1, inp.shape[-1]), padded_m_splits) + out_ret = torch.cat( + [grad_output_mat[: m_splits[i]] for i, grad_output_mat in enumerate(inputmats)], dim=0 + ) + + if is_grad_enabled: + ctx.m_splits = m_splits + ctx.padded_m_splits = padded_m_splits + ctx.requires_dgrad = inp.requires_grad + + return out_ret + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + grad_input = None + if ctx.requires_dgrad: + grad_output = grad_output.contiguous() + + in_features = grad_output.shape[-1] + + # Allocate cast and transpose output tensor + total_row = sum(ctx.padded_m_splits) + grad_input = torch.empty( + [total_row, in_features], dtype=grad_output.dtype, device=grad_output.device + ) + # FP8 pad input for forward, FP8 input transpose for backward wgrad + multi_padding_fused( + grad_output.view(-1, in_features), ctx.m_splits, ctx.padded_m_splits, grad_input + ) + + return (grad_input, None, None, None) + + +class Fp8Unpadding(torch.nn.Module): + """ + Apply the unpadding for Grouped GEMM input. + + Parameters + ---------- + num_gemms: int + number of GEMMs to be performed simutaneously. + """ + + def __init__( + self, + num_gemms, + ) -> None: + super().__init__() + + self.num_gemms = num_gemms + + @no_torch_dynamo() + def forward( + self, + inp: torch.Tensor, + m_splits: List[int], + ) -> torch.Tensor: + """ + Apply the unpadding to the input. + + Parameters + ---------- + inp : torch.Tensor + Input tensor. + m_splits : List[int] + List of integers representing the split of the input tensor. + """ + + assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs." 
+ + # FP8 padding calculate + padded_m_splits = [(m + 15) // 16 * 16 for m in m_splits] + + if torch.is_grad_enabled(): + fn = _Fp8Unpadding.apply + args = [] + else: + fn = _Fp8Unpadding.forward + args = [None] + + args += ( + inp, + m_splits, + padded_m_splits, + torch.is_grad_enabled(), + ) + out = fn(*args) + + return out From 27d808261b75e00470e515180399d2b72b78ee08 Mon Sep 17 00:00:00 2001 From: beinggod Date: Wed, 4 Sep 2024 02:43:58 +0000 Subject: [PATCH 2/2] refine ut & simplify multi-padding kernel Signed-off-by: beinggod --- tests/pytorch/test_numerics.py | 8 +-- transformer_engine/common/util/padding.cu | 82 ++++++++++------------- 2 files changed, 38 insertions(+), 52 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index cb1ab54ca1..723f68369b 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -357,7 +357,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: return (input > 0) * input * input -class TorcGroupedLinearWithPadding(nn.Module): +class TorchGroupedLinearWithPadding(nn.Module): def __init__( self, num_gemms, in_features, out_features, bias, params_dtype, parallel_mode, fp8 @@ -1434,7 +1434,7 @@ def _generate_random_numbers(n, total_sum): m_splits = _generate_random_numbers(num_gemms, config.seq_len * bs) with fp8_autocast(enabled=fp8): - if isinstance(block, TorcGroupedLinearWithPadding): + if isinstance(block, TorchGroupedLinearWithPadding): out = block(inp_hidden_states, m_splits) else: if fp8: @@ -1461,7 +1461,7 @@ def _generate_random_numbers(n, total_sum): @pytest.mark.parametrize("num_gemms", [3, 6]) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", model_configs.keys()) -@pytest.mark.parametrize("fp8", all_boolean) +@pytest.mark.parametrize("fp8", [True]) @pytest.mark.parametrize("fp8_model_params", all_boolean) def test_padding_grouped_linear_accuracy( dtype, num_gemms, bs, model, fp8, fp8_model_params, parallel_mode=None @@ -1474,7 +1474,7 @@ def test_padding_grouped_linear_accuracy( pytest.skip("FP8 requires sequence length to be divisible by 16.") with fp8_model_init(enabled=fp8 and fp8_model_params): - grouped_linear = TorcGroupedLinearWithPadding( + grouped_linear = TorchGroupedLinearWithPadding( num_gemms, config.hidden_size, 4 * config.hidden_size, diff --git a/transformer_engine/common/util/padding.cu b/transformer_engine/common/util/padding.cu index d7ed12f8e4..017d2e6a56 100644 --- a/transformer_engine/common/util/padding.cu +++ b/transformer_engine/common/util/padding.cu @@ -21,8 +21,7 @@ namespace { // Parameters to tune constexpr int n_warps_per_tile = 4; constexpr int threads_per_block = THREADS_PER_WARP * n_warps_per_tile; -constexpr int desired_load_size = 8; -constexpr int desired_store_size = 8; +constexpr int desired_load_store_size = 8; constexpr int kMaxTensorsPerKernel = 64; // Args must be <4 KB struct MultiPaddingArgs { @@ -42,11 +41,9 @@ struct MultiPaddingArgs { int num_tensors; }; -template +template __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiPaddingArgs args) { - using IVec = Vec; - using OVecC = Vec; - using OVecT = Vec; + using Vec = Vec; // Thread indices // Note: Block is interpreted as a warp_size x num_warps grid @@ -58,11 +55,11 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP const int bid = blockIdx.x; // Input tensors are divided into tiles - // Note: Each tile is a warp_size x warp_size grid of nvec_out x nvec_in subtiles - constexpr int 
tile_dim_m = THREADS_PER_WARP * nvec_out; - constexpr int tile_dim_n = THREADS_PER_WARP * nvec_in; + // Note: Each tile is a warp_size x warp_size grid of nvec x nvec subtiles + constexpr int tile_dim_m = THREADS_PER_WARP * nvec; + constexpr int tile_dim_n = THREADS_PER_WARP * nvec; - // Number of nvec_out x nvec_in subtiles for each thread to + // Number of nvec x nvec subtiles for each thread to // load/store constexpr int n_iterations = THREADS_PER_WARP / n_warps_per_tile; @@ -71,8 +68,8 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP while (args.block_range[tensor_id + 1] <= bid) { ++tensor_id; } - const IType* input = reinterpret_cast(args.input_list[tensor_id]); - OType* output = reinterpret_cast(args.output_list[tensor_id]); + const Type* input = reinterpret_cast(args.input_list[tensor_id]); + Type* output = reinterpret_cast(args.output_list[tensor_id]); const int num_rows = args.num_rows_list[tensor_id]; const int padded_num_rows = args.padded_num_rows_list[tensor_id]; const int row_length = args.row_length_list[tensor_id]; @@ -88,40 +85,38 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP // Load input and store to registers // Note: Each thread loads n_iterations subtiles, casts to output // type, and transposes in registers. - OType local_zero = static_cast(0.f); + Type local_zero = static_cast(0.f); #pragma unroll for (int iter = 0; iter < n_iterations; ++iter) { const int i1 = tidy + iter * bdimy; const int j1 = tidx; #pragma unroll - for (int i2 = 0; i2 < nvec_out; ++i2) { - const int row = tile_row + i1 * nvec_out + i2; - const int col = tile_col + j1 * nvec_in; - IVec local_input; - OVecC local_output; + for (int i2 = 0; i2 < nvec; ++i2) { + const int row = tile_row + i1 * nvec + i2; + const int col = tile_col + j1 * nvec; + Vec local_input; + Vec local_output; local_input.clear(); if (row < num_rows) { - for (int j2 = 0; j2 < nvec_in; ++j2) { + for (int j2 = 0; j2 < nvec; ++j2) { if (col + j2 < row_length) { local_input.data.elt[j2] = input[row * row_length + col + j2]; } } } #pragma unroll - for (int j2 = 0; j2 < nvec_in; ++j2) { - const CType x = CType(local_input.data.elt[j2]); - const OType y = OType(x); - local_output.data.elt[j2] = y; + for (int j2 = 0; j2 < nvec; ++j2) { + local_output.data.elt[j2] = local_input.data.elt[j2]; } if (row < num_rows) { - for (int j2 = 0; j2 < nvec_in; ++j2) { + for (int j2 = 0; j2 < nvec; ++j2) { if (col + j2 < row_length) { output[row * row_length + col + j2] = local_output.data.elt[j2]; } } } else if (row < padded_num_rows) { // padding - for (int j2 = 0; j2 < nvec_in; ++j2) { + for (int j2 = 0; j2 < nvec; ++j2) { if (col + j2 < row_length) { output[row * row_length + col + j2] = local_zero; } @@ -143,16 +138,15 @@ void multi_padding(const std::vector input_list, std::vector o } // Check that tensor properties are valid - DType itype = input_list[0]->data.dtype; - DType otype = output_list[0]->data.dtype; + DType type = input_list[0]->data.dtype; for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) { const auto& input = *input_list[tensor_id]; const auto& output = *output_list[tensor_id]; CheckInputTensor(input, "multi_padding_input_" + std::to_string(tensor_id)); CheckInputTensor(output, "multi_padding_output_" + std::to_string(tensor_id)); - NVTE_CHECK(input.data.dtype == itype, "Input tensor types do not match."); - NVTE_CHECK(output.data.dtype == otype, "Output tensor types do not match."); + NVTE_CHECK(input.data.dtype == type, "Input tensor 
types do not match."); + NVTE_CHECK(output.data.dtype == type, "Output tensor types do not match."); NVTE_CHECK(input.data.shape.size() == 2, "Input tensor must have 2 dimensions."); NVTE_CHECK(output.data.shape[0] == padded_num_rows_list[tensor_id], @@ -160,9 +154,9 @@ void multi_padding(const std::vector input_list, std::vector o } // Input matrices are divided into tiles - // Note: Each tile is a warp_size x warp_size grid of nvec_out x nvec_in subtiles - const int tile_dim_m = THREADS_PER_WARP * desired_store_size / typeToSize(otype); - const int tile_dim_n = THREADS_PER_WARP * desired_load_size / typeToSize(itype); + // Note: Each tile is a warp_size x warp_size grid of nvec x nvec subtiles + const int tile_dim_m = THREADS_PER_WARP * desired_load_store_size / typeToSize(type); + const int tile_dim_n = THREADS_PER_WARP * desired_load_store_size / typeToSize(type); // Add tensors to kernel argument struct MultiPaddingArgs kernel_args; @@ -172,14 +166,10 @@ void multi_padding(const std::vector input_list, std::vector o // Launch kernel if argument struct is full if (kernel_args.num_tensors == kMaxTensorsPerKernel) { TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( - itype, InputType, - TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( - otype, OutputType, constexpr int nvec_in = desired_load_size / sizeof(InputType); - constexpr int nvec_out = desired_store_size / sizeof(OutputType); - const int n_blocks = kernel_args.block_range[kernel_args.num_tensors]; - multi_padding_kernel - <<>>(kernel_args);); // NOLINT(*) - ); // NOLINT(*) + type, Type, constexpr int nvec = desired_load_store_size / sizeof(Type); + const int n_blocks = kernel_args.block_range[kernel_args.num_tensors]; + multi_padding_kernel + <<>>(kernel_args);); // NOLINT(*) kernel_args.num_tensors = 0; } @@ -205,14 +195,10 @@ void multi_padding(const std::vector input_list, std::vector o // Launch kernel if (kernel_args.num_tensors > 0) { TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( - itype, InputType, - TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( - otype, OutputType, constexpr int nvec_in = desired_load_size / sizeof(InputType); - constexpr int nvec_out = desired_store_size / sizeof(OutputType); - const int n_blocks = kernel_args.block_range[kernel_args.num_tensors]; - multi_padding_kernel - <<>>(kernel_args);); // NOLINT(*) - ); // NOLINT(*) + type, Type, constexpr int nvec = desired_load_store_size / sizeof(Type); + const int n_blocks = kernel_args.block_range[kernel_args.num_tensors]; + multi_padding_kernel + <<>>(kernel_args);); // NOLINT(*) } }