Remove MmaOpDetails::input_layout and getInputLayout #3322

Merged: 1 commit, merged on Oct 31, 2024
csrc/ir/utils.cpp (0 additions & 58 deletions)
@@ -1202,55 +1202,6 @@ TensorViewDetails getDetailsFor(const std::vector<IterDomain*>& dims) {
return details;
}

MmaLayout getInputLayout(
const TensorViewDetails& in_a,
const TensorViewDetails& in_b,
const MmaOp::AxesData& m_axes,
const MmaOp::AxesData& n_axes,
const MmaOp::AxesData& k_axes) {
// TT layout (b - broadcast, r - reduction):
// A = [M, K, b]
// B = [b, K, N]
// C = [M, r, N] (root domain)
if ((m_axes.front() < in_a.bcasts.front()) &&
(k_axes.front() < in_a.bcasts.front()) &&
(in_b.bcasts.front() < k_axes.front()) &&
(in_b.bcasts.front() < n_axes.front())) {
return MmaLayout::TT;
}
// TN layout (b - broadcast, r - reduction):
// A = [M, b, K]
// B = [b, N, K]
// C = [M, N, r] (root domain)
if ((m_axes.front() < in_a.bcasts.front()) &&
(in_a.bcasts.front() < k_axes.front()) &&
(in_b.bcasts.front() < n_axes.front()) &&
(in_b.bcasts.front() < k_axes.front())) {
return MmaLayout::TN;
}
// NT layout (b - broadcast, r - reduction):
// A = [K, M, b]
// B = [K, b, N]
// C = [r, M, N] (root domain)
if ((k_axes.front() < in_a.bcasts.front()) &&
(m_axes.front() < in_a.bcasts.front()) &&
(k_axes.front() < in_b.bcasts.front()) &&
(in_b.bcasts.front() < n_axes.front())) {
return MmaLayout::NT;
}
// NN layout (b - broadcast, r - reduction):
// A = [b, K, M]
// B = [N, K, b]
// C = [N, r, M] (root domain)
if ((in_a.bcasts.front() < k_axes.front()) &&
(k_axes.front() < m_axes.front()) && (n_axes.front() < k_axes.front()) &&
(k_axes.front() < in_b.bcasts.front())) {
return MmaLayout::NN;
}

NVF_THROW("Unsupported input layout");
}
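For reference while auditing this removal, here is a self-contained sketch of the ordering rule the deleted classifier encoded, shown for the TN case. Plain ints stand in for the AxesData::front() positions compared above; looksLikeTN and the main harness are illustrative names, not nvFuser API.

#include <cassert>

// TN layout: A = [M, b, K], B = [b, N, K]. In A the M axis precedes the
// broadcast axis, which precedes K; in B the broadcast axis precedes both
// N and K. This mirrors the four front()-position comparisons above.
bool looksLikeTN(int m, int k, int bcast_a, int n, int bcast_b) {
  return m < bcast_a && bcast_a < k && bcast_b < n && bcast_b < k;
}

int main() {
  // A = [M, b, K] -> m = 0, bcast_a = 1, k = 2
  // B = [b, N, K] -> bcast_b = 0, n = 1, k = 2
  assert(looksLikeTN(/*m=*/0, /*k=*/2, /*bcast_a=*/1, /*n=*/1, /*bcast_b=*/0));
  return 0;
}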

MmaOpDetails getMmaOpDetails(
TensorView* out,
TensorView* in_a,
@@ -1383,15 +1334,6 @@ MmaOpDetails getMmaOpDetails(
!details.k_axes.empty(),
"MmaOp inputs must define at least a single K dimension");

// TODO: for tensor contraction / split-k uses of MmaOp different input layout
// rules may be needed
details.input_layout = getInputLayout(
in_a_details,
in_b_details,
details.m_axes,
details.n_axes,
details.k_axes);

return details;
}

csrc/ir/utils.h (0 additions & 2 deletions)
@@ -38,8 +38,6 @@ struct MmaOpDetails {
// Concrete or broadcast axes that are present in all inputs
// and output
AxesData batch_axes;
// A placeholder for mma input layout
std::optional<MmaLayout> input_layout = std::nullopt;
};
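With the placeholder gone, the struct carries only axis groupings. A minimal sketch of the remaining definition, assuming only the fields referenced elsewhere in this diff; the AxesData alias and the M/N/K comments are inferred, and the real header may differ.

struct MmaOpDetails {
  using AxesData = MmaOp::AxesData; // assumed alias, matching getInputLayout's former parameters
  // Axes playing the M / N / K roles in the MmaOp (inferred from getMmaOpDetails)
  AxesData m_axes;
  AxesData n_axes;
  AxesData k_axes;
  // Concrete or broadcast axes that are present in all inputs and output
  AxesData batch_axes;
};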

// A helper structure with pieces of information about TensorView
tests/cpp/test_matmul.cpp (63 additions & 0 deletions)
@@ -140,6 +140,69 @@ TEST_P(MatmulTestWithLayout, AmpereMatmul) {
NVF_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001));
}

// Single batch dimension that is broadcast
TEST_P(MatmulTestWithLayout, AmpereMatmulBroadcastBatch) {
// Keep multiples of 8 to keep vectorizable.
int M = 504, N = 136, K = 248;

Fusion fusion;
FusionGuard fg(&fusion);

auto shapes = matmulAtInputShape3DTuring(-1, -1, -1, layout);

auto tv0 = makeContigConcreteTensor(shapes.first, DataType::Half);
auto tv1 = makeContigConcreteTensor(shapes.second, DataType::Half);

fusion.addInput(tv0);
fusion.addInput(tv1);

tv0 = canonicalizeInputToBMNK(tv0, layout, MmaOperand::A);
tv1 = canonicalizeInputToBMNK(tv1, layout, MmaOperand::B);
// Broadcast inputs to [1, M, 1, K] and [1, 1, N, K]
tv0 = broadcast(tv0, {true, false, false, false});
tv1 = broadcast(tv1, {true, false, false, false});
auto tv2 = fusedMultiplySum(tv0, tv1, {-1});

fusion.addOutput(tv2);

MatMulTileOptions gemm_tile;
gemm_tile.cta_tile = GemmTile(128, 128, 32);
gemm_tile.warp_tile = GemmTile(64, 64, 32);
gemm_tile.instruction_tile = GemmTile(16, 8, 16);

MatmulParams mparams;
mparams.supported_vec_size = {8, 8, 4};
mparams.mma_macro = MmaMacro::Ampere_16_8_16;
mparams.tile_sizes = gemm_tile;
mparams.async_gmem_load_operands = true;
mparams.circular_buffer_options.circular_buffer_smem_write = true;
mparams.circular_buffer_options.circular_buffer_smem_read = true;
mparams.circular_buffer_options.smem_circular_buffer_stage = 4;
SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
->schedule(&fusion, &mparams);

auto inputs = matmulAtInput3DTuring(M, N, K, layout);

FusionExecutor fe;
NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
8,
0,
fe.compileFusion(
&fusion,
{inputs.first, inputs.second},
LaunchParams(),
matmul_cparams));
ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty());
ASSERT_FALSE(
PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel()));
auto cg_outputs = fe.runFusion({inputs.first, inputs.second});
auto tref =
atMatmul(
inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout)
.unsqueeze(0);
NVF_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001));
}
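As a shape sanity check for the test above, here is a minimal ATen-only sketch of the reference it compares against, assuming TN operand shapes; broadcastBatchReference is an illustrative helper, not one of the test utilities.

#include <ATen/ATen.h>

// Multiply-and-reduce over K with a broadcast batch dimension of size 1:
// [1, M, 1, K] * [1, 1, N, K] summed over K gives [1, M, N], matching
// atMatmul(...).unsqueeze(0) in the test above.
at::Tensor broadcastBatchReference(const at::Tensor& a, const at::Tensor& b) {
  // a: [M, K], b: [N, K], e.g. Half tensors upcast to float for accumulation
  auto a4 = a.to(at::kFloat).view({1, a.size(0), 1, a.size(1)}); // [1, M, 1, K]
  auto b4 = b.to(at::kFloat).view({1, 1, b.size(0), b.size(1)}); // [1, 1, N, K]
  // Materializes a [1, M, N, K] intermediate; acceptable for a small check.
  return (a4 * b4).sum(/*dim=*/-1); // [1, M, N]
}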

TEST_P(MatmulTestWithLayout, AmperePrologueFusionBroadcast) {
// Keep multiples of 8 to keep vectorizable.
int M = 504, N = 136, K = 248;