Skip to content

Commit

Permalink
Merge branch 'llu/2d_inner_reduction_heuristics' into llu/clean_2d_in…
Browse files Browse the repository at this point in the history
…ner_reduction_heuristics
  • Loading branch information
liqiangxl authored Nov 8, 2024
2 parents bdeb737 + 68dd6c4 commit c339d72
Show file tree
Hide file tree
Showing 181 changed files with 4,904 additions and 4,020 deletions.
5 changes: 5 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ set(NVFUSER_SRCS_DIR "${NVFUSER_ROOT}/csrc")
set(NVFUSER_THIRD_PARTY_DIR "${NVFUSER_ROOT}/third_party")

option(NVFUSER_STANDALONE_BUILD_WITH_UCC "" OFF)
option(NVFUSER_EXPLICIT_ERROR_CHECK "" OFF)
if (NVFUSER_EXPLICIT_ERROR_CHECK)
add_compile_definitions(NVFUSER_EXPLICIT_ERROR_CHECK)
endif()
option(NVFUSER_BUILD_WITH_ASAN "Build nvFuser with asan" OFF)

include(CMakeDependentOption)
Expand Down Expand Up @@ -545,6 +549,7 @@ list(APPEND JIT_TEST_SRCS
${NVFUSER_ROOT}/tests/cpp/test_id_model.cpp
${NVFUSER_ROOT}/tests/cpp/test_indexing.cpp
${NVFUSER_ROOT}/tests/cpp/test_indexing_advanced.cpp
${NVFUSER_ROOT}/tests/cpp/test_inlining.cpp
${NVFUSER_ROOT}/tests/cpp/test_iter_visitor.cpp
${NVFUSER_ROOT}/tests/cpp/test_linked_hash_map.cpp
${NVFUSER_ROOT}/tests/cpp/test_loop_domain_scheduling.cpp
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/cpp/batch_norm_channels_first.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ static void setupBatchNorm(Fusion* fusion, DataType dtype) {

static void NvFuserScheduler_BatchNorm(
benchmark::State& benchmark_state,
FusionExecutorCache* fusion_executor_cache,
FusionExecutorCache* executor_cache,
DataType dtype) {
NVF_ERROR(dtype == DataType::Float || dtype == DataType::Half);

Expand All @@ -102,7 +102,7 @@ static void NvFuserScheduler_BatchNorm(
std::vector<c10::IValue> aten_inputs(
{at_x, at_weight, at_bias, at_run_mean, at_run_var});

runBenchmarkIterations(benchmark_state, fusion_executor_cache, aten_inputs);
runBenchmarkIterations(benchmark_state, executor_cache, aten_inputs);

benchmark_state.SetBytesProcessed(
int64_t(benchmark_state.iterations()) *
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/cpp/batch_norm_channels_first_backward.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ static void setupBatchNorm_BWD(Fusion* fusion, DataType dtype) {

static void NvFuserScheduler_BatchNorm_BWD(
benchmark::State& benchmark_state,
FusionExecutorCache* fusion_executor_cache,
FusionExecutorCache* executor_cache,
DataType dtype) {
NVF_ERROR(dtype == DataType::Float || dtype == DataType::Half);

Expand All @@ -115,7 +115,7 @@ static void NvFuserScheduler_BatchNorm_BWD(
std::vector<c10::IValue> aten_inputs(
{input, grad_out, weight, run_mean, run_var, save_mean, save_var});

runBenchmarkIterations(benchmark_state, fusion_executor_cache, aten_inputs);
runBenchmarkIterations(benchmark_state, executor_cache, aten_inputs);

benchmark_state.SetBytesProcessed(
int64_t(benchmark_state.iterations()) *
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/cpp/batch_norm_channels_last.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ static void setupBatchNorm_nhwc(Fusion* fusion, DataType dtype) {

static void NvFuserScheduler_BatchNorm_nhwc(
benchmark::State& benchmark_state,
FusionExecutorCache* fusion_executor_cache,
FusionExecutorCache* executor_cache,
DataType dtype) {
NVF_ERROR(dtype == DataType::Float || dtype == DataType::Half);

Expand All @@ -103,7 +103,7 @@ static void NvFuserScheduler_BatchNorm_nhwc(
std::vector<c10::IValue> aten_inputs(
{at_x, at_weight, at_bias, at_run_mean, at_run_var});

runBenchmarkIterations(benchmark_state, fusion_executor_cache, aten_inputs);
runBenchmarkIterations(benchmark_state, executor_cache, aten_inputs);

benchmark_state.SetBytesProcessed(
int64_t(benchmark_state.iterations()) *
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/cpp/batch_norm_channels_last_backward.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ static void setupBatchNorm_nhwc_BWD(Fusion* fusion, DataType dtype) {

static void NvFuserScheduler_BatchNorm_nhwc_BWD(
benchmark::State& benchmark_state,
FusionExecutorCache* fusion_executor_cache,
FusionExecutorCache* executor_cache,
DataType dtype) {
NVF_ERROR(dtype == DataType::Float || dtype == DataType::Half);

Expand All @@ -116,7 +116,7 @@ static void NvFuserScheduler_BatchNorm_nhwc_BWD(
std::vector<c10::IValue> aten_inputs(
{input, grad_out, weight, run_mean, run_var, save_mean, save_var});

runBenchmarkIterations(benchmark_state, fusion_executor_cache, aten_inputs);
runBenchmarkIterations(benchmark_state, executor_cache, aten_inputs);

benchmark_state.SetBytesProcessed(
int64_t(benchmark_state.iterations()) *
Expand Down
24 changes: 12 additions & 12 deletions benchmarks/cpp/bert.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ static void setupDivMaxSoftmaxDropoutBackward(Fusion* fusion, DataType dtype) {

static void NvFuserScheduler_DivMaxSoftDropFwd(
benchmark::State& benchmark_state,
FusionExecutorCache* fusion_executor_cache,
FusionExecutorCache* executor_cache,
DataType dtype) {
auto w = benchmark_state.range(0);
auto x = benchmark_state.range(1);
Expand All @@ -135,15 +135,15 @@ static void NvFuserScheduler_DivMaxSoftDropFwd(
std::vector<c10::IValue> at_inputs = {t0, t1};

auto bytes =
runBenchmarkIterations(benchmark_state, fusion_executor_cache, at_inputs);
runBenchmarkIterations(benchmark_state, executor_cache, at_inputs);

benchmark_state.SetBytesProcessed(
bytes * int64_t(benchmark_state.iterations()));
}

static void NvFuserScheduler_DivMaxSoftDropBwd(
benchmark::State& benchmark_state,
FusionExecutorCache* fusion_executor_cache,
FusionExecutorCache* executor_cache,
DataType dtype) {
auto w = benchmark_state.range(0);
auto x = benchmark_state.range(1);
Expand All @@ -162,7 +162,7 @@ static void NvFuserScheduler_DivMaxSoftDropBwd(
std::vector<c10::IValue> at_inputs = {t0, t1, t2, t3};

auto bytes =
runBenchmarkIterations(benchmark_state, fusion_executor_cache, at_inputs);
runBenchmarkIterations(benchmark_state, executor_cache, at_inputs);

// Some reason t1 isn't used, ignore it.
bytes -=
Expand Down Expand Up @@ -228,7 +228,7 @@ static void setupBiasDropoutAddLayernormFwd(Fusion* fusion, DataType dtype) {

static void NvFuserScheduler_BiasDropoutAddLayernormFwd(
benchmark::State& benchmark_state,
FusionExecutorCache* fusion_executor_cache,
FusionExecutorCache* executor_cache,
DataType dtype) {
auto x = benchmark_state.range(0);
auto y = benchmark_state.range(1);
Expand All @@ -247,7 +247,7 @@ static void NvFuserScheduler_BiasDropoutAddLayernormFwd(
std::vector<c10::IValue> at_inputs = {t0, t1, t2, t3, t4};

auto bytes =
runBenchmarkIterations(benchmark_state, fusion_executor_cache, at_inputs);
runBenchmarkIterations(benchmark_state, executor_cache, at_inputs);

benchmark_state.SetBytesProcessed(
bytes * int64_t(benchmark_state.iterations()));
Expand Down Expand Up @@ -304,7 +304,7 @@ static void setupBiasDropoutAddLayernormBwd1(Fusion* fusion, DataType dtype) {

static void NvFuserScheduler_BiasDropoutAddLayernormBwd1(
benchmark::State& benchmark_state,
FusionExecutorCache* fusion_executor_cache,
FusionExecutorCache* executor_cache,
DataType dtype) {
auto x = benchmark_state.range(0);
auto y = benchmark_state.range(1);
Expand All @@ -322,7 +322,7 @@ static void NvFuserScheduler_BiasDropoutAddLayernormBwd1(
std::vector<c10::IValue> at_inputs = {t0, t1, t2, t3};

auto bytes =
runBenchmarkIterations(benchmark_state, fusion_executor_cache, at_inputs);
runBenchmarkIterations(benchmark_state, executor_cache, at_inputs);

benchmark_state.SetBytesProcessed(
bytes * int64_t(benchmark_state.iterations()));
Expand Down Expand Up @@ -380,7 +380,7 @@ static void setupBiasDropoutAddLayernormBwd2(Fusion* fusion, DataType dtype) {

static void NvFuserScheduler_BiasDropoutAddLayernormBwd2(
benchmark::State& benchmark_state,
FusionExecutorCache* fusion_executor_cache,
FusionExecutorCache* executor_cache,
DataType dtype) {
auto x = benchmark_state.range(0);
auto y = benchmark_state.range(1);
Expand All @@ -398,7 +398,7 @@ static void NvFuserScheduler_BiasDropoutAddLayernormBwd2(
std::vector<c10::IValue> at_inputs = {t4, t5, t1, t8};

auto bytes =
runBenchmarkIterations(benchmark_state, fusion_executor_cache, at_inputs);
runBenchmarkIterations(benchmark_state, executor_cache, at_inputs);

benchmark_state.SetBytesProcessed(
bytes * int64_t(benchmark_state.iterations()));
Expand Down Expand Up @@ -438,7 +438,7 @@ static void setupBiasDropoutAddLayernormBwd3(Fusion* fusion, DataType dtype) {

static void NvFuserScheduler_BiasDropoutAddLayernormBwd3(
benchmark::State& benchmark_state,
FusionExecutorCache* fusion_executor_cache,
FusionExecutorCache* executor_cache,
DataType dtype) {
auto x = benchmark_state.range(0);
auto y = benchmark_state.range(1);
Expand All @@ -454,7 +454,7 @@ static void NvFuserScheduler_BiasDropoutAddLayernormBwd3(
std::vector<c10::IValue> at_inputs = {t0, t21};

auto bytes =
runBenchmarkIterations(benchmark_state, fusion_executor_cache, at_inputs);
runBenchmarkIterations(benchmark_state, executor_cache, at_inputs);

benchmark_state.SetBytesProcessed(
bytes * int64_t(benchmark_state.iterations()));
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/cpp/broadcast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ static void setupBroadcast(Fusion* fusion, DataType dtype, int bcast_axis) {

static void NvFuserScheduler_Broadcast(
benchmark::State& benchmark_state,
FusionExecutorCache* fusion_executor_cache,
FusionExecutorCache* executor_cache,
DataType dtype,
int bcast_dim) {
auto bcast_size = benchmark_state.range(0);
Expand All @@ -74,7 +74,7 @@ static void NvFuserScheduler_Broadcast(

std::vector<c10::IValue> aten_inputs({t0, t1});

runBenchmarkIterations(benchmark_state, fusion_executor_cache, aten_inputs);
runBenchmarkIterations(benchmark_state, executor_cache, aten_inputs);

benchmark_state.SetBytesProcessed(
int64_t(benchmark_state.iterations()) *
Expand Down
28 changes: 14 additions & 14 deletions benchmarks/cpp/gelu_backward.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,8 @@ static void NvFuserScheduler_GeluBackward_Compile(
&fusion, SchedulerType::PointWise, c10::ArrayRef<c10::IValue>(inputs));

for (auto _ : benchmark_state) {
FusionExecutor executor;
executor.compileFusion(&fusion, inputs, heuristic_params->lparams);
KernelExecutor ke;
ke.compile(&fusion, inputs, heuristic_params->lparams);
}
}

Expand All @@ -187,14 +187,14 @@ static void NvFuserScheduler_GeluBackward_RunFusion(
auto heuristic_params = SchedulerEntry::scheduleWith(
&fusion, SchedulerType::PointWise, c10::ArrayRef<c10::IValue>(inputs));

FusionExecutor executor;
executor.compileFusion(&fusion, inputs, heuristic_params->lparams);
KernelExecutor ke;
ke.compile(&fusion, inputs, heuristic_params->lparams);

C10_CUDA_CHECK(cudaDeviceSynchronize());

for (auto _ : benchmark_state) {
outputs = executor.runFusion(
c10::ArrayRef<c10::IValue>(inputs), heuristic_params->lparams);
outputs =
ke.run(c10::ArrayRef<c10::IValue>(inputs), heuristic_params->lparams);
C10_CUDA_CHECK(cudaDeviceSynchronize());
clearL2Cache();
}
Expand All @@ -218,11 +218,11 @@ static void NvFuserScheduler_GeluBackward_RunFusion_GpuOnly(
auto heuristic_params = SchedulerEntry::scheduleWith(
&fusion, SchedulerType::PointWise, c10::ArrayRef<c10::IValue>(inputs));

FusionExecutor executor;
executor.compileFusion(&fusion, inputs, heuristic_params->lparams);
KernelExecutor ke;
ke.compile(&fusion, inputs, heuristic_params->lparams);

runBenchmarkIterations(
benchmark_state, &executor, inputs, heuristic_params->lparams);
benchmark_state, &ke, inputs, heuristic_params->lparams);
}

BENCHMARK(NvFuserScheduler_GeluBackward_RunFusion_GpuOnly)
Expand All @@ -247,13 +247,13 @@ static void NvFuserScheduler_GeluBackward_RunFusion_CpuOnly(
auto heuristic_params = SchedulerEntry::scheduleWith(
&fusion, SchedulerType::PointWise, c10::ArrayRef<c10::IValue>(inputs));

FusionExecutor executor;
executor.setExecuteKernelFlag(false);
executor.compileFusion(&fusion, inputs, heuristic_params->lparams);
KernelExecutor ke;
ke.setExecuteKernelFlag(false);
ke.compile(&fusion, inputs, heuristic_params->lparams);

for (auto _ : benchmark_state) {
outputs = executor.runFusion(
c10::ArrayRef<c10::IValue>(inputs), heuristic_params->lparams);
outputs =
ke.run(c10::ArrayRef<c10::IValue>(inputs), heuristic_params->lparams);
}
}

Expand Down
4 changes: 2 additions & 2 deletions benchmarks/cpp/gelu_backward_reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ static void setupGeluBackwardReduction(

static void NvFuserScheduler_GeluBackwardReduction(
benchmark::State& benchmark_state,
FusionExecutorCache* fusion_executor_cache,
FusionExecutorCache* executor_cache,
DataType dtype,
int reduction_dim) {
auto reduction_size = benchmark_state.range(0);
Expand All @@ -112,7 +112,7 @@ static void NvFuserScheduler_GeluBackwardReduction(

std::vector<c10::IValue> aten_inputs = {aten_input_grad, aten_input_x};

runBenchmarkIterations(benchmark_state, fusion_executor_cache, aten_inputs);
runBenchmarkIterations(benchmark_state, executor_cache, aten_inputs);

// inputs: gradient tensor + input tensor
// outputs: output, output_of_reduction
Expand Down
24 changes: 12 additions & 12 deletions benchmarks/cpp/heuristic_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ using namespace nvfuser;

static auto getLayerBackwardNormRuntime(
std::unique_ptr<Fusion> fusion_ptr,
std::unique_ptr<FusionExecutorCache>& fec,
std::unique_ptr<FusionExecutorCache>& executor_cache,
std::vector<c10::IValue>& aten_inputs,
std::vector<int64_t>& shape,
std::vector<int64_t>& norm_shape) {
Expand Down Expand Up @@ -84,12 +84,12 @@ static auto getLayerBackwardNormRuntime(
auto aten_mean = std::get<1>(aten_results);
auto aten_rstd = std::get<2>(aten_results);

fec = std::make_unique<FusionExecutorCache>(std::move(fusion_ptr));
executor_cache = std::make_unique<FusionExecutorCache>(std::move(fusion_ptr));
aten_inputs = {
aten_grad_out, aten_input, aten_mean, aten_rstd, aten_weight, aten_bias};
auto cg_outputs = fec->runFusionWithInputs(aten_inputs);
auto cg_outputs = executor_cache->runFusionWithInputs(aten_inputs);

return fec->getMostRecentKernelRuntime();
return executor_cache->getMostRecentKernelRuntime();
}

static void NvFuserScheduler_LayerNormBackward_HeuristicCache(
Expand All @@ -98,14 +98,14 @@ static void NvFuserScheduler_LayerNormBackward_HeuristicCache(
FusionGuard fg(fusion_ptr.get());

// PreAllocate
std::unique_ptr<FusionExecutorCache> fec;
std::unique_ptr<FusionExecutorCache> executor_cache;
std::vector<c10::IValue> aten_inputs;

std::vector<int64_t> shape{20, 100, 35, 67};
std::vector<int64_t> norm_shape{67};

auto runtime = getLayerBackwardNormRuntime(
std::move(fusion_ptr), fec, aten_inputs, shape, norm_shape);
std::move(fusion_ptr), executor_cache, aten_inputs, shape, norm_shape);

KernelArgumentHolder args =
KernelArgumentHolder::createKernelArgumentHolder(aten_inputs);
Expand All @@ -120,7 +120,7 @@ static void NvFuserScheduler_LayerNormBackward_HeuristicCache(

static auto getLayerForwardNormRuntime(
std::unique_ptr<Fusion> fusion_ptr,
std::unique_ptr<FusionExecutorCache>& fec,
std::unique_ptr<FusionExecutorCache>& executor_cache,
std::vector<c10::IValue>& aten_inputs,
std::vector<int64_t>& shape,
std::vector<int64_t>& norm_shape) {
Expand All @@ -141,11 +141,11 @@ static auto getLayerForwardNormRuntime(
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor aten_input = at::randn(shape, options);

fec = std::make_unique<FusionExecutorCache>(std::move(fusion_ptr));
executor_cache = std::make_unique<FusionExecutorCache>(std::move(fusion_ptr));
aten_inputs = {aten_input};
auto cg_outputs = fec->runFusionWithInputs(aten_inputs);
auto cg_outputs = executor_cache->runFusionWithInputs(aten_inputs);

return fec->getMostRecentKernelRuntime();
return executor_cache->getMostRecentKernelRuntime();
}

static void NvFuserScheduler_LayerNormForward_HeuristicCache(
Expand All @@ -154,14 +154,14 @@ static void NvFuserScheduler_LayerNormForward_HeuristicCache(
FusionGuard fg(fusion_ptr.get());

// PreAllocate
std::unique_ptr<FusionExecutorCache> fec;
std::unique_ptr<FusionExecutorCache> executor_cache;
std::vector<c10::IValue> aten_inputs;

std::vector<int64_t> shape{20, 100, 35, 67};
std::vector<int64_t> norm_shape{67};

auto runtime = getLayerForwardNormRuntime(
std::move(fusion_ptr), fec, aten_inputs, shape, norm_shape);
std::move(fusion_ptr), executor_cache, aten_inputs, shape, norm_shape);

KernelArgumentHolder args =
KernelArgumentHolder::createKernelArgumentHolder(aten_inputs);
Expand Down
Loading

0 comments on commit c339d72

Please sign in to comment.