Skip to content

Commit

Permalink
Create HopperMultiMatmulScheduler (#3243)
Browse files Browse the repository at this point in the history
This PR creates `HopperMultiMatmulScheduler`.

## Changes:
* `scheduleMultipleMatmuls` in `multi_matmul.h` selects between
`AmpereMultiMatmulScheduler` and `HopperMultiMatmulScheduler` based on
device compute capability.
* Rename `MultiMatmulScheduler` to `AmpereMatmulScheduler`. Move it from
`multi_matmul.cpp` to `ampere_multi_matmul.h`
* Create `HopperMultiMatmulScheduler` in `hopper_multi_matmul.h`
* Update arch guard in `test_matmul_scheduler.cpp`

## Why?
Pros:
- For deprecation, we could remove lines in `MatmulScheduler` and delete
old files.
- Assume uniform compute capability in each `MultipleMatmulScheduler`.
- All reusable code would go matmul_utils.h and mma_util.h.

Cons:
- Older generations would have fewer capabilities, since most
development would go to the current generation.
- More lines of code.
- `MatmulParams` could contain parameters for different generations.
  • Loading branch information
rdspring1 authored Oct 25, 2024
1 parent 85c22a2 commit 896a28a
Show file tree
Hide file tree
Showing 7 changed files with 3,589 additions and 1,638 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,8 @@ list(APPEND NVFUSER_SRCS
${NVFUSER_SRCS_DIR}/scheduler/mark_aliases.cpp
${NVFUSER_SRCS_DIR}/scheduler/matmul.cpp
${NVFUSER_SRCS_DIR}/scheduler/multi_matmul.cpp
${NVFUSER_SRCS_DIR}/scheduler/ampere_multi_matmul.cpp
${NVFUSER_SRCS_DIR}/scheduler/hopper_multi_matmul.cpp
${NVFUSER_SRCS_DIR}/scheduler/matmul_heuristic_plugin.cpp
${NVFUSER_SRCS_DIR}/scheduler/matmul_utils.cpp
${NVFUSER_SRCS_DIR}/scheduler/mma_utils.cpp
Expand Down
1,524 changes: 1,524 additions & 0 deletions csrc/scheduler/ampere_multi_matmul.cpp

Large diffs are not rendered by default.

241 changes: 241 additions & 0 deletions csrc/scheduler/ampere_multi_matmul.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#pragma once

#include <ATen/cuda/CUDAContext.h>
#include <scheduler/mma_utils.h>
#include <val_graph.h>
#include <val_graph_visitor.h>

namespace nvfuser {

// MmaOps in the scheduled tensor. Each one outputs a TensorView* which we call
// an mma_result. Each MmaOp will also have two input TensorViews which we call
// "ab" and "bb" since they are the immediate A and B operands and they contain
// broadcast dimensions. Again there can be multiple abs and multiple bbs in
// one fusion. These TensorViews are loaded from global memory tensors that we
// call "a" and "b" into shared memory tensors called acw_smem and bcw_smem.
// They are loaded from shared memory to register buffers we call "acr" and
// "bcr" ("cr" meaning "cache read" in this context).
//
// Putting this all together we have the following order for a simple matmul
//
// a -> acw_smem -> acr -> ... -> ab
// \ .
// mma_result -> ... -> dc -> d
// /
// b -> bcw_smem -> bcr -> ... -> bb
//
// The ... indicate that there might be other tensors involved in a prologue or
// epilogue section at that location.
//
// In this example there are two matmuls both using the same "a" operand:
//
// b1 -> bcw_smem1 -> bcr1 -> ... -> bb1
// \ .
// mma_result1
// / \ .
// a -> acw_smem -> acr -> ... -> ab ... -> dc -> d
// \ /
// mma_result2
// /
// b2 -> bcw_smem2 -> bcr2 -> ... -> bb2
//
// Note that there can be more than one output d and each one will have its own
// register cache dc.
//
// Split-K and smem epilogue unswizzling add two additional tensors for each
// mma in the fusion: splitk_sum and smem_epilogue.
//
// // No split-K, no smem epilogue unswizzling:
// mma_result -> ... -> dc -> d
// // split-K, no smem epilogue unswizzling:
// mma_result -> splitk_sum -> ... -> dc -> d
// // smem epilogue unswizzling, no split-K:
// mma_result -> smem_epilogue -> ... -> dc -> d
// // split-K and smem epilogue unswizzling:
// mma_result -> smem_epilogue -> splitk_sum -> ... -> dc -> d
//
// These additional tensors are added to each mma_result in the fusion.
//
// Each of the named tensors above is scheduled differently. We schedule them
// by building AbstractTensors for each tensor category; these are held in
// AmpereMultipleMatmulScheduler::schedules_.
// TODO: Inheret from SchedulerEntry
class AmpereMultipleMatmulScheduler {
public:
AmpereMultipleMatmulScheduler(Fusion* fusion, const MatmulParams* params)
: fusion_(fusion),
params_(params),
id_model_(fusion, /*build_graphs=*/false) {
const auto device_prop = at::cuda::getCurrentDeviceProperties();
const int cc = device_prop->major * 10 + device_prop->minor;
NVF_ERROR(
cc >= 75 && cc < 90,
"This matmul scheduler is restricted to Ampere and Turing.");
}

void run();

private:
void cacheInputsAndOutputs();

void findPatterns();

void countDims();

void translatePatterns();

// Get tensor roles and id roles
// When there are multiple matmul patterns, we can have conflicting roles.
// For now we throw an error if this is the case.
// TODO: This should be checked in canScheduleCompileTime
void findRoles();

// Including current tensor naming convention for reference,
// this is very temporary and will change over time and
// in fact the whole body of this function will
// eventually be a set of utility functions for different
// sections of matmul(fusion) kernels, with
// each having its own build out to do.
//
// Current naming convention is based on the following formula:
//
// d = alpha * (a x b) + beta * c
//
// and is defined in the following way:
//
// operands assumed in global memory : a, b, c
//
// registers staging global load : ar, br (short for a/b read)
//
// shared mem cache of operands : acw_smem, bcw_smem (short for a/b
// cache_write smem)
//
// registers at shared memory load output : acr, bcr (short for a/b cache
// read)
//
// register tensor input to the actual mma op: ab, bb (short for a/b
// broadcasted)
//
// accumulator register: mma_result
// - mma_result is MmaOp output if there is epilogue
// - mma_result is dc (short for d cache) if there is no epilogue
//
// result in global memory: d

// Currently the support is for a, b, c and d as fusion inputs/outputs
// aka. no prolog fusion yet.
void defineOperandCaches();

void cacheOperandsToSmem(
const std::vector<TensorView*>& operands,
std::vector<TensorView*>& smem_operands,
int64_t vec_size);

// We add two LoadStore operators to the inputs of our fusions. The first
// one is for a read from global memory and the second one (below) is for a
// cache read. As an optimizaton, we avoid adding an operator if there's an
// existing LoadStoreOp present. Please note that for the second LoadStore
// we don't propagate the allocation domain, since the scheduler sets the
// allocation domain in the registers.
void addSetsForCacheReads(
const std::vector<TensorView*>& tv_smems,
std::vector<TensorView*>& tv_rs);

//! Rebuilds IdModel, then updates all ValGroups in abstract tensors to refer
//! to the new IdModel. This is necessary whenever we perform an operation
//! that creates a new TensorView, such as caching or rFactor
void updateIdModel();

//! Swizzle the M and N outer dimensions after makeTile has been called.
//! This updates outer_dim_roles if we introduce a new dimension, which can
//! happen if tv is missing a merged axis, in which case we skip merging after
//! the split. This is analogous to forwarding during transform propagation.
void swizzleBlockTiles(
TensorView* tv,
std::vector<MatmulDimRole>& outer_dim_roles);

//! This calls orig->cacheAfter() and also updates the permissive graph to
//! reflect the new IterDomain mappings
TensorView* cacheAfter(
TensorView* orig,
LoadStoreOpType op_type = LoadStoreOpType::Set,
CacheOp cache_op = CacheOp::AllLevels,
bool propagate_allocation_domain = false);

//! Do block tiling for a collection of TensorViews. The tensors should be
//! unscheduled before this method is called.
//! 1) Axes will be ordered according to canonicalDimOrdering, and then axes
//! with the same role will be merged.
//! 2) After that, we perform splits according to
//! params_->tile_sizes.cta_tile, e.g. [M, K] -> [Mo, Ko, Mi, Ki].
//! 3) Depending on the value of params_->grid_swizzle_factor, if the TV has
//! both M and N dimensions, we perform a 2D swizzle of the outer dimensions
//! Mo and No.
//! 4) Finally, we do a split-K split if the splitk_factor is not 1
std::vector<std::vector<MatmulDimRole>> blockTileTensors(
const std::vector<TensorView*>& tvs);

//! Schedule the loads of all operands from global memory to shared memory.
//! Starting from the basic tiled schedule, we swizzle the operand memory.
//! Note that the cache op and LoadStoreOpType are already set during
//! defineOperandCaches().
void scheduleOperandSmemStores();

void scheduleMmaOperands(
std::vector<TensorView*>& tvs,
const std::optional<MmaOperand> operand_type);

// MmaOperand contains only A and B. If tvs are outputs (i.e. not operands),
// then operand_type should be std::nullopt.
void scheduleMmaResults();

void schedulePrologues();

void scheduleOutputTensor(TensorView* c);

void scheduleEpilogue();

//! Propagates transformations from fusion output to fusion tv inputs that are
//! producers in the epilogue. Transformations' propagation aims at input tvs
//! which are not assigned to core roles, that is, are not MMA inputs.
void scheduleFusionInputsForEpilogue();

void scheduleSplitKSum();

void setUpInlining();

// NOTE: this should be called after acw_smem, acr, ..., ab, and mma_result
// transforms have been applied and inlining
void setUpCircularBuffering();

private:
Fusion* fusion_;
const MatmulParams* params_;
IdModel id_model_;
// Permissive graph of id_model_, which we modify at times using e.g.
// AbstractTensor.split or by mapping vals in cacheAfter and rFactor
ValGraph* graph_ = nullptr;
std::vector<mma_utils::MatmulPattern> patterns_;
mma_utils::DimRolesMap id_roles_;
mma_utils::TensorRolesMap tensor_roles_;
mma_utils::MatmulOperandInnerDims inner_dims_;

int64_t num_splitk_dims_ = 0, num_device_dims_ = 0, num_local_batch_dims_ = 0,
num_device_and_batch_dims_ = 0;

std::vector<std::pair<TensorView*, TensorView*>> cached_outputs_;

std::vector<ValGroup> canonical_dim_ordering_;

std::vector<TensorView*> as_, bs_, acw_smems_, bcw_smems_, acrs_, bcrs_, abs_,
bbs_, mma_results_, splitk_sums_, smem_epilogues_;
};

} // namespace nvfuser
Loading

0 comments on commit 896a28a

Please sign in to comment.