diff --git a/test/unit/util/CMakeLists.txt b/test/unit/util/CMakeLists.txt index 449d6f62a8..0e88b065bc 100644 --- a/test/unit/util/CMakeLists.txt +++ b/test/unit/util/CMakeLists.txt @@ -30,4 +30,5 @@ cutlass_test_unit_add_executable( cutlass_test_unit_util tensor_reduce.cu cutlass_test_levels.cu + rms_norm.cu ) diff --git a/test/unit/util/rms_norm.cu b/test/unit/util/rms_norm.cu new file mode 100644 index 0000000000..a3e6595dae --- /dev/null +++ b/test/unit/util/rms_norm.cu @@ -0,0 +1,123 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+#include "../common/cutlass_unit_test.h"
+
+#include "cutlass/util/device_rmsnorm.h"
+#include "cutlass/util/host_tensor.h"
+#include "cutlass/constants.h"
+#include "cutlass/util/reference/host/tensor_copy.h"
+#include "cutlass/util/reference/host/tensor_fill.h"
+#include "cutlass/util/reference/host/tensor_compare.h"
+
+#include <algorithm>
+#include <cmath>
+
+using ElementType = cutlass::half_t;
+using Layout = cutlass::layout::RowMajor;
+
+/// Host reference for RMS normalization.
+///
+/// For every row m of `input` this computes
+///   output(m, n) = input(m, n) / sqrt(mean_n(input(m, n)^2) + 1e-6) * weight(0, n)
+/// with all arithmetic in fp32 and a single conversion to ElementType on store.
+///
+/// \param tensor_size  (rows, columns) extent of `input` and `output`
+/// \param output       destination tensor, written at every (m, n)
+/// \param input        source tensor
+/// \param weight       per-column scale; only row 0 is read
+void rmsnorm_host(cutlass::MatrixCoord tensor_size,
+                  cutlass::TensorRef<ElementType, Layout> output,
+                  cutlass::TensorRef<ElementType, Layout> input,
+                  cutlass::TensorRef<ElementType, Layout> weight) {
+  const int M = tensor_size.row();
+  const int N = tensor_size.column();
+
+  for (int m = 0; m < M; ++m) {
+    // Accumulate the sum of squares of the row in fp32.
+    float square_sum{0};
+
+    for (int n = 0; n < N; ++n) {
+      float inp = static_cast<float>(input.at({m, n}));
+      square_sum += inp * inp;
+    }
+
+    float sq_mean = square_sum / (float)N;
+    // The 1e-6 term keeps the divisor non-zero for an all-zero row.
+    float sqrt_var = cutlass::fast_sqrt(sq_mean + (float)1e-6);
+
+    for (int n = 0; n < N; ++n) {
+      float inp = static_cast<float>(input.at({m, n}));
+      float g = static_cast<float>(weight.at({0, n}));
+      float res_fp32 = inp / sqrt_var * g;
+      output.at({m, n}) = ElementType(res_fp32);
+    }
+  }
+}
+
+/// Runs cutlass::rmsnorm on an M x N problem and compares it element-wise
+/// against rmsnorm_host. The test fails when either the maximum or the mean
+/// absolute difference reaches 0.001.
+void run_test(int M, int N) {
+  cutlass::HostTensor<ElementType, Layout> input, output_ref, output, weight;
+  input.reset({M, N});
+  output.reset({M, N});
+  output_ref.reset({M, N});
+  weight.reset({1, N});
+
+  const unsigned seed = 2022;
+
+  cutlass::reference::host::TensorFillRandomUniform(input.host_view(),
+                                                    seed,
+                                                    ElementType(5),
+                                                    ElementType(-5),
+                                                    0);
+
+  // NOTE(review): the weight fill reuses the input seed; kept as-is so the
+  // generated data (and therefore the pass/fail behavior) is unchanged.
+  cutlass::reference::host::TensorFillRandomUniform(weight.host_view(),
+                                                    seed,
+                                                    ElementType(5),
+                                                    ElementType(-5),
+                                                    0);
+
+  input.sync_device();
+  weight.sync_device();
+
+  rmsnorm_host({M, N}, output_ref.host_ref(), input.host_ref(), weight.host_ref());
+  cutlass::rmsnorm({M, N}, output.device_ref(),
+                   input.device_ref(), weight.device_ref(), nullptr);
+
+  output.sync_host();
+
+  float max_abs_diff = -1;
+  float mean_abs_diff = 0;
+  for (int m = 0; m < M; ++m) {
+    for (int n = 0; n < N; ++n) {
+      // Widen both operands to fp32 before subtracting so the comparison
+      // itself does not round through half precision.
+      const float expected = static_cast<float>(output_ref.at({m, n}));
+      const float actual = static_cast<float>(output.at({m, n}));
+      const float diff = std::fabs(expected - actual);
+      mean_abs_diff += diff;
+      max_abs_diff = std::max(max_abs_diff, diff);
+    }
+  }
+
+  mean_abs_diff /= float(M * N);
+
+  EXPECT_TRUE(max_abs_diff < 0.001f && mean_abs_diff < 0.001f)
+    << "Max absolute difference : " << max_abs_diff << "\n"
+    << "Mean absolute difference: " << mean_abs_diff;
+}
+
+// N is a multiple of 8: cutlass::rmsnorm takes its vectorized half kernel.
+TEST(RMSNorm, 16x1024) {
+  run_test(16, 1024);
+}
+
+// N is not a multiple of 8: cutlass::rmsnorm takes its scalar kernel.
+TEST(RMSNorm, 1x127) {
+  run_test(1, 127);
+}
diff --git a/tools/util/include/cutlass/util/device_rmsnorm.h b/tools/util/include/cutlass/util/device_rmsnorm.h
new file mode 100644
index 0000000000..5090efa0df
--- /dev/null
+++ b/tools/util/include/cutlass/util/device_rmsnorm.h
@@ -0,0 +1,185 @@
+/******************************************************************************
+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ *
+ ******************************************************************************/
+
+#pragma once
+
+#include "cutlass/cutlass.h"
+#include "cutlass/layout/tensor.h"
+#include "cutlass/numeric_types.h"
+#include "cutlass/tensor_coord.h"
+#include "cutlass/tensor_ref.h"
+#include "cutlass/util/device_utils.h"
+
+// NOTE(review): the header name on the original bare `#include` line was lost
+// in extraction; the standard headers below are the ones this file's visible
+// code needs (std::uintptr_t, std::abort, std::cerr, std::is_same).
+#include <cstdint>
+#include <cstdlib>
+#include <iostream>
+#include <type_traits>
+
+namespace cutlass {
+
+/// RMS normalization of half-precision rows, eight elements per float4 access.
+///
+/// Launch layout: blockIdx.x selects the row; the threads of the block stride
+/// over the row's n / 8 float4 vectors. Uses one __shared__ float.
+///
+/// Preconditions (not checked here):
+///   - n is a multiple of 8 (n / 8 truncates, a remainder would be skipped);
+///   - every float4 is treated as eight packed half values;
+///   - rows are addressed as densely packed, row r starts at vector r * (n / 8);
+///   - float4 accesses need 16-byte aligned pointers.
+///
+/// \param output  destination rows
+/// \param input   source rows
+/// \param weight  one row of per-column scales, shared by all rows
+/// \param m       row count (unused; the grid size determines the rows)
+/// \param n       elements per row
+__global__ void rmsnorm_twoPassAlgo_e8(float4 *output, const float4 *input,
+                                       const float4 *weight,
+                                       const int m, const int n) {
+  const int m_idx = blockIdx.x;
+  const int tid = threadIdx.x;
+  const int bdimx = blockDim.x;
+  __shared__ float s_mean;
+  float local_sums[1] = {0.0f};
+  const int n_8 = n / 8;
+  // 64-bit row offset: m_idx * n_8 can exceed INT_MAX for large tensors.
+  const size_t offset = static_cast<size_t>(m_idx) * static_cast<size_t>(n_8);
+  input += offset;
+  output += offset;
+
+  // Pass 1: per-thread partial sum of squares, accumulated in fp32.
+  for (int index = tid; index < n_8; index += bdimx) {
+    const float4 local_val = input[index];
+    // View the 16 bytes as four half2 pairs (.x, .y, .z, .w in that order).
+    const half2 *in_h2 = reinterpret_cast<const half2 *>(&local_val);
+    float acc = 0.0f;
+#pragma unroll
+    for (int i = 0; i < 4; ++i) {
+      const float lo = __half2float(in_h2[i].x);
+      const float hi = __half2float(in_h2[i].y);
+      acc += lo * lo;
+      acc += hi * hi;
+    }
+    local_sums[0] += acc;
+  }
+
+  // Combine the partial sums across the block. Thread 0 consumes
+  // local_sums[0] afterwards (helpers come from device_utils.h).
+  if (blockDim.x <= 32) {
+    warpReduceSum<float, 1>(local_sums);
+  } else {
+    blockReduceSum<float, 1>(local_sums);
+  }
+  if (threadIdx.x == 0) {
+    // Despite its name, s_mean holds 1 / sqrt(mean(x^2) + eps).
+    s_mean = rsqrtf(local_sums[0] / n + 1e-6f);
+  }
+  __syncthreads();
+
+  // Pass 2: scale every element by the reciprocal RMS and its column weight.
+  const float inv_rms = s_mean;
+  for (int index = tid; index < n_8; index += bdimx) {
+    const float4 local_val = input[index];
+    const float4 weight_val = weight[index];
+
+    const half2 *in_h2 = reinterpret_cast<const half2 *>(&local_val);
+    const half2 *w_h2 = reinterpret_cast<const half2 *>(&weight_val);
+
+    float4 tmp;
+    half2 *out_h2 = reinterpret_cast<half2 *>(&tmp);
+
+#pragma unroll
+    for (int i = 0; i < 4; ++i) {
+      out_h2[i].x = __float2half(__half2float(in_h2[i].x) * inv_rms *
+                                 __half2float(w_h2[i].x));
+      out_h2[i].y = __float2half(__half2float(in_h2[i].y) * inv_rms *
+                                 __half2float(w_h2[i].y));
+    }
+
+    output[index] = tmp;
+  }
+}
+
+/// Scalar RMS normalization: one element per access, any n, any T that
+/// converts to and from float.
+///
+/// Launch layout: blockIdx.x selects the row; the threads of the block stride
+/// over the row's n elements. Uses one __shared__ float. Rows are addressed as
+/// densely packed (row r starts at element r * n).
+///
+/// \param output  destination rows
+/// \param input   source rows
+/// \param weight  one row of per-column scales, shared by all rows
+/// \param m       row count (unused; the grid size determines the rows)
+/// \param n       elements per row
+template <typename T>
+__global__ void rmsnorm_twoPassAlgo_e1(T* output,
+                                       const T* input,
+                                       const T* weight,
+                                       const int m, const int n)
+{
+  const int m_idx = blockIdx.x;
+  const int tid = threadIdx.x;
+  const int bdimx = blockDim.x;
+  __shared__ float s_mean;
+  float local_sums[1] = {0.0f};
+  // 64-bit row offset: m_idx * n can exceed INT_MAX for large tensors.
+  const size_t offset = static_cast<size_t>(m_idx) * static_cast<size_t>(n);
+  input += offset;
+  output += offset;
+
+  // Pass 1: per-thread partial sum of squares, accumulated in fp32.
+  for (int index = tid; index < n; index += bdimx) {
+    const float local_val = static_cast<float>(input[index]);
+    local_sums[0] += local_val * local_val;
+  }
+
+  if (blockDim.x <= 32) {
+    warpReduceSum<float, 1>(local_sums);
+  } else {
+    blockReduceSum<float, 1>(local_sums);
+  }
+  if (threadIdx.x == 0) {
+    // Despite its name, s_mean holds 1 / sqrt(mean(x^2) + eps).
+    s_mean = rsqrtf(local_sums[0] / n + 1e-6f);
+  }
+  __syncthreads();
+
+  // Pass 2: scale every element by the reciprocal RMS and its column weight.
+  const float inv_rms = s_mean;
+  for (int index = tid; index < n; index += bdimx) {
+    const T weight_val = weight[index];
+    const T local_val = input[index];
+    output[index] = T(static_cast<float>(local_val) * inv_rms *
+                      static_cast<float>(weight_val));
+  }
+}
+
+/// Launches RMS normalization over the rows of a row-major matrix on `stream`.
+///
+/// Launches one thread block per row. The float4 kernel is used when T is
+/// cutlass::half_t, the column count is a multiple of 8, and all three base
+/// pointers are 16-byte aligned; otherwise the scalar kernel is used. Prints
+/// the CUDA error and aborts if the launch fails. The launch is asynchronous;
+/// this function does not synchronize the stream.
+///
+/// NOTE(review): only data() of each TensorRef is read; the stride is not
+/// consulted, so rows are assumed to be densely packed.
+///
+/// \param tensor_size  (rows, columns) extent of input and output
+/// \param ref_output   destination tensor
+/// \param ref_input    source tensor
+/// \param ref_weight   per-column scales (a single row is read)
+/// \param stream       CUDA stream to launch on
+template <typename T>
+void rmsnorm(cutlass::MatrixCoord tensor_size,
+             TensorRef<T, layout::RowMajor> ref_output,
+             TensorRef<T, layout::RowMajor> ref_input,
+             TensorRef<T, layout::RowMajor> ref_weight,
+             cudaStream_t stream){
+  const int m = tensor_size.row();
+  const int n = tensor_size.column();
+
+  // Nothing to normalize; a zero-sized grid or block is an invalid launch.
+  if (m <= 0 || n <= 0) {
+    return;
+  }
+
+  T* output = ref_output.data();
+  const T* input = ref_input.data();
+  const T* weight = ref_weight.data();
+  dim3 grid(m);
+
+  // float4 accesses fault on pointers that are not 16-byte aligned.
+  const bool aligned16 =
+      (reinterpret_cast<std::uintptr_t>(output) % 16 == 0) &&
+      (reinterpret_cast<std::uintptr_t>(input) % 16 == 0) &&
+      (reinterpret_cast<std::uintptr_t>(weight) % 16 == 0);
+
+  if (n % 8 == 0 && std::is_same<T, cutlass::half_t>::value && aligned16) {
+    // One thread per float4, rounded up to whole warps, capped at 1024.
+    dim3 block(min(1024, (n / 8 + 31) / 32 * 32));
+
+    rmsnorm_twoPassAlgo_e8<<<grid, block, 0, stream>>>(
+      (float4 *)output, (const float4 *)input, (const float4 *)weight, m, n);
+  } else {
+    // Roughly one thread per 32 elements, rounded up to whole warps.
+    dim3 block(min(1024, ((n + 31)/32 + 31)/32*32));
+
+    rmsnorm_twoPassAlgo_e1<<<grid, block, 0, stream>>>(
+      output, input, weight, m, n);
+  }
+
+  // Launch-configuration errors surface here; in-kernel faults surface at a
+  // later synchronizing call.
+  auto result = cudaGetLastError();
+  if (result != cudaSuccess) {
+    std::cerr << "CUDA error: " << cudaGetErrorString(result) << std::endl;
+    std::abort();
+  }
+}
+
+} // namespace cutlass