Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PyTorch] Implement Fp8 padding and unpadding module #1129

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tests/cpp/operator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ add_executable(test_operator
test_layernorm.cu
test_rmsnorm.cu
test_multi_cast_transpose.cu
test_multi_padding.cu
test_causal_softmax.cu
../test_common.cu)

Expand Down
169 changes: 169 additions & 0 deletions tests/cpp/operator/test_multi_padding.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <random>
#include <vector>
#include <cstdio>

#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>

#include <transformer_engine/padding.h>
#include "../test_common.h"

using namespace transformer_engine;

namespace {

// Host-side reference for multi-tensor padding.
//
// For every tensor, the first min(height, padded_height) rows of the row-major
// input are copied to the output (each element is converted through float and
// then to OutputType), and any remaining rows up to padded_height are set to
// zero. Output elements past padded_height * width are left untouched.
template <typename InputType, typename OutputType>
void compute_ref(const std::vector<std::vector<InputType>>& input_list,
                 std::vector<std::vector<OutputType>>& output_list,
                 const std::vector<size_t>& height_list,
                 const std::vector<size_t>& width_list,
                 const std::vector<int>& padded_height_list) {
  const size_t num_tensors = input_list.size();
  for (size_t t = 0; t < num_tensors; ++t) {
    const std::vector<InputType>& src = input_list[t];
    std::vector<OutputType>& dst = output_list[t];

    const size_t rows = height_list[t];
    const size_t cols = width_list[t];
    const size_t padded_rows = padded_height_list[t];

    // Rows are contiguous, so the row/column loops collapse into flat ranges.
    const size_t copied_rows = rows < padded_rows ? rows : padded_rows;
    const size_t copy_end = copied_rows * cols;
    const size_t pad_end = padded_rows * cols;

    size_t idx = 0;
    for (; idx < copy_end; ++idx) {
      const float value = static_cast<float>(src[idx]);
      dst[idx] = static_cast<OutputType>(value);
    }
    for (; idx < pad_end; ++idx) {
      dst[idx] = static_cast<OutputType>(0.f);
    }
  }
}

template <typename InputType, typename OutputType>
void performTest() {
using namespace test;

const DType itype = TypeInfo<InputType>::dtype;
const DType otype = TypeInfo<OutputType>::dtype;
const std::vector<std::pair<size_t, size_t>> tensor_dims = {{1,1},
{1,768},
{768,1},
{768,768},
{43,43},
{43,256},
{256,43},
{256,256}};
const size_t num_tensors = tensor_dims.size();
constexpr int align = 16;

// Buffers for Transformer Engine implementation
std::vector<Tensor> input_list, output_list, output_t_list;

// Buffers for reference implementation
std::vector<std::vector<InputType>> ref_input_list;
std::vector<std::vector<OutputType>> ref_output_list;
std::vector<size_t> ref_height_list(num_tensors), ref_width_list(num_tensors);
std::vector<int> ref_padded_height_list(num_tensors);

// Initialize buffers
for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) {
const size_t height = tensor_dims[tensor_id].first;
const size_t width = tensor_dims[tensor_id].second;
const size_t padded_height = (height + align - 1) / align * align;
input_list.emplace_back(Tensor({ height, width }, itype));
output_list.emplace_back(Tensor({ padded_height, width }, otype));

auto& input = input_list.back();
auto& output = output_list.back();
fillUniform(&input);
setRandomScale(&output);

ref_input_list.emplace_back(height*width);
ref_output_list.emplace_back(padded_height*width);

std::copy(input.cpu_dptr<InputType>(),
input.cpu_dptr<InputType>() + height * width,
ref_input_list.back().begin());
ref_height_list[tensor_id] = height;
ref_width_list[tensor_id] = width;
ref_padded_height_list[tensor_id] = padded_height;
}

// Transformer Engine implementation
auto make_nvte_vector = [](std::vector<Tensor>& tensor_list)
-> std::vector<NVTETensor> {
std::vector<NVTETensor> nvte_tensor_list;
for (auto& tensor : tensor_list) {
nvte_tensor_list.emplace_back(tensor.data());
}
return nvte_tensor_list;
};
nvte_multi_padding(num_tensors,
make_nvte_vector(input_list).data(),
make_nvte_vector(output_list).data(),
ref_padded_height_list.data(),
0);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

// Reference implementation
compute_ref<InputType, OutputType>(ref_input_list,
ref_output_list,
ref_height_list,
ref_width_list,
ref_padded_height_list);

// Check correctness
for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) {
auto [atol, rtol] = getTolerances(otype);
compareResults("output",
output_list[tensor_id],
ref_output_list[tensor_id].data(),
atol, rtol);
}
}

} // namespace

// Value-parameterized fixture; the parameter is the data type under test.
class MultiPaddingTestSuite : public ::testing::TestWithParam<transformer_engine::DType> {
};

// Runs performTest with the suite parameter used as both the input and the
// output element type.
TEST_P(MultiPaddingTestSuite, TestMultiPaddingTranspose) {
  using namespace transformer_engine;
  using namespace test;

  const DType input_type = GetParam();

  // The output type always equals the input type, so one dispatch is enough.
  // Nesting a second switch would instantiate performTest for every
  // (input, output) pair even though only the matching pair can ever run.
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
    performTest<InputType, InputType>();
  );
}


// Instantiate the suite for every entry of test::all_fp_types; each test is
// named after its data type via test::typeName.
INSTANTIATE_TEST_SUITE_P(
    OperatorTest,
    MultiPaddingTestSuite,
    ::testing::ValuesIn(test::all_fp_types),
    [](const testing::TestParamInfo<MultiPaddingTestSuite::ParamType>& param_info) -> std::string {
      return test::typeName(param_info.param);
    });
189 changes: 189 additions & 0 deletions tests/pytorch/test_numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Dict, List, Optional
import pytest
import copy
import random

import torch
import torch.nn as nn
Expand All @@ -30,6 +31,8 @@
TransformerLayer,
LayerNorm,
InferenceParams,
Fp8Padding,
Fp8Unpadding,
)
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
from transformer_engine.pytorch.cpp_extensions import fp8_gemm, fp8_grouped_gemm, gemm, grouped_gemm
Expand Down Expand Up @@ -354,6 +357,40 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
return (input > 0) * input * input


class TorcGroupedLinearWithPadding(nn.Module):
    """``GroupedLinear`` wrapped with ``Fp8Padding`` / ``Fp8Unpadding``.

    When ``fp8`` is truthy, the input and its splits go through ``Fp8Padding``
    before the grouped linear layer, and the result goes through
    ``Fp8Unpadding`` with the original splits. Otherwise the input is handed
    straight to the grouped linear layer.
    """

    def __init__(
        self, num_gemms, in_features, out_features, bias, params_dtype, parallel_mode, fp8
    ) -> None:
        super().__init__()

        grouped_linear_kwargs = {
            "bias": bias,
            "params_dtype": params_dtype,
            "parallel_mode": parallel_mode,
            "device": "cuda",
        }

        # Submodules are created in the same order as before: padding, linear, unpadding.
        self.padding = Fp8Padding(num_gemms)
        self.linear_fn = GroupedLinear(
            num_gemms, in_features, out_features, **grouped_linear_kwargs
        )
        self.unpadding = Fp8Unpadding(num_gemms)

        self.fp8 = fp8

    def forward(self, inp: torch.Tensor, m_splits: List[int]) -> torch.Tensor:
        """Apply the grouped linear layer, padding and unpadding only in FP8 mode."""
        if not self.fp8:
            return self.linear_fn(inp, m_splits)

        padded_inp, padded_m_splits = self.padding(inp, m_splits)
        padded_out = self.linear_fn(padded_inp, padded_m_splits)
        # Unpad with the caller's original splits, not the padded ones.
        return self.unpadding(padded_out, m_splits)


_supported_act = {
"geglu": nn.GELU(approximate="tanh"),
"gelu": nn.GELU(approximate="tanh"),
Expand Down Expand Up @@ -1328,6 +1365,158 @@ def test_grouped_linear_accuracy_parallel_mode(parallel_mode):
)


def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False):
    """Run one forward/backward pass of ``block`` on random input with random splits.

    If ``block`` is a ``TorcGroupedLinearWithPadding`` it is called directly.
    Otherwise, when ``fp8`` is set, the input is padded manually so each split is
    a multiple of 16, passed through ``block``, and unpadded again; without
    ``fp8`` the block is called on the unpadded input.

    Returns a list holding the output, the gradient of the input, and the
    gradient of every parameter of ``block`` that requires grad.
    """

    def _pad_tensor_for_fp8(hidden_states, tokens_per_expert):
        """Padding tensor shapes to multiples of 16."""
        padded_tokens_per_expert = [
            (num_tokens + 15) // 16 * 16 for num_tokens in tokens_per_expert
        ]
        hidden_states = torch.split(hidden_states, tokens_per_expert)
        padded_hidden_states = []
        for hidden_state, actual_num_tokens, padded_num_tokens in zip(
            hidden_states, tokens_per_expert, padded_tokens_per_expert
        ):
            padded_hidden_states.append(hidden_state)
            if padded_num_tokens > actual_num_tokens:
                # Append zero rows to reach the padded row count for this split.
                pad_tensor = torch.zeros(
                    padded_num_tokens - actual_num_tokens,
                    hidden_state.shape[1],
                    dtype=hidden_state.dtype,
                    device=hidden_state.device,
                )
                padded_hidden_states.append(pad_tensor)
        padded_hidden_states = torch.cat(padded_hidden_states, dim=0)
        return padded_hidden_states, padded_tokens_per_expert

    def _unpad_tensor_for_fp8(padded_hidden_states, actual_tokens_per_expert, tokens_per_expert):
        """Drop the padding rows: keep the first actual-token rows of each padded split."""
        padded_chunks = torch.split(
            padded_hidden_states.view(-1, padded_hidden_states.shape[-1]), tokens_per_expert
        )
        hidden_states = torch.cat(
            [
                padded_chunk[: actual_tokens_per_expert[i]]
                for i, padded_chunk in enumerate(padded_chunks)
            ],
            dim=0,
        )

        return hidden_states

    def _generate_random_numbers(n, total_sum):
        """Return ``n`` positive integers that sum to ``total_sum`` (empty list if n <= 0)."""
        if n <= 0:
            return []

        # reset seed (``seed`` is expected to be a module-level name defined
        # outside this excerpt)
        random.seed(seed)

        if n == 1:
            # No break points are needed; indexing an empty ``breaks`` list
            # below would raise IndexError.
            return [total_sum]

        # n - 1 distinct break points in [1, total_sum) cut the range into n parts.
        breaks = sorted(random.sample(range(1, total_sum), n - 1))
        return [upper - lower for lower, upper in zip([0] + breaks, breaks + [total_sum])]

    reset_rng_states()
    if fp8:
        FP8GlobalStateManager.reset()

    inp_hidden_states = torch.randn(
        (config.seq_len * bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    inp_hidden_states.retain_grad()

    m_splits = _generate_random_numbers(num_gemms, config.seq_len * bs)

    with fp8_autocast(enabled=fp8):
        if isinstance(block, TorcGroupedLinearWithPadding):
            out = block(inp_hidden_states, m_splits)
        else:
            if fp8:
                padded_inp_hidden_states, padding_m_splits = _pad_tensor_for_fp8(
                    inp_hidden_states, m_splits
                )
                padded_inp_hidden_states = block(padded_inp_hidden_states, padding_m_splits)
                out = _unpad_tensor_for_fp8(padded_inp_hidden_states, m_splits, padding_m_splits)
            else:
                out = block(inp_hidden_states, m_splits)

    loss = out.sum()
    loss.backward()

    torch.cuda.synchronize()
    outputs = [out, inp_hidden_states.grad]
    for p in block.parameters():
        if p.requires_grad:
            outputs.append(p.grad)
    return outputs


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("num_gemms", [3, 6])
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("fp8", all_boolean)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
def test_padding_grouped_linear_accuracy(
    dtype, num_gemms, bs, model, fp8, fp8_model_params, parallel_mode=None
):
    """Compare ``TorcGroupedLinearWithPadding`` with a plain ``GroupedLinear``.

    Both modules receive the same weights and are run through
    ``_test_padding_grouped_linear_accuracy``; all returned tensors must be
    exactly equal.
    """
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)

    config = model_configs[model]
    if fp8 and config.seq_len % 16 != 0:
        pytest.skip("FP8 requires sequence length to be divisible by 16.")

    use_fp8_params = fp8 and fp8_model_params
    in_features = config.hidden_size
    out_features = 4 * config.hidden_size
    common_kwargs = {
        "bias": False,
        "params_dtype": dtype,
        "parallel_mode": parallel_mode,
    }

    with fp8_model_init(enabled=use_fp8_params):
        grouped_linear = TorcGroupedLinearWithPadding(
            num_gemms, in_features, out_features, fp8=fp8, **common_kwargs
        ).eval()

    with fp8_model_init(enabled=use_fp8_params):
        ref_grouped_linear = GroupedLinear(
            num_gemms, in_features, out_features, device="cuda", **common_kwargs
        ).eval()

    # Share params: copy each weight of the wrapped layer into the reference layer.
    with torch.no_grad():
        source_linear = grouped_linear.linear_fn
        for gemm_idx in range(num_gemms):
            weight_name = f"weight{gemm_idx}"
            shared_weight = Parameter(getattr(source_linear, weight_name).clone())
            setattr(ref_grouped_linear, weight_name, shared_weight)

    outputs = _test_padding_grouped_linear_accuracy(
        grouped_linear, num_gemms, bs, dtype, config, fp8
    )
    outputs_ref = _test_padding_grouped_linear_accuracy(
        ref_grouped_linear, num_gemms, bs, dtype, config, fp8
    )

    # Should be a bit-wise match
    for actual, expected in zip(outputs, outputs_ref):
        torch.testing.assert_close(actual, expected, rtol=0, atol=0)


def _test_gpt_e2e_cuda_graph(block, bs, dtype, config, graph):
reset_rng_states()

Expand Down
Loading