Return KernelArgumentHolder instead of std::vector<at::Tensor> #3946

Open
wants to merge 7 commits into main from polymorphic_outs_step_8

Conversation

csarofeen
Collaborator

No description provided.

Return KernelArgumentHolder instead of std::vector<at::Tensor>
@csarofeen force-pushed the polymorphic_outs_step_8 branch from 3d3db12 to d88fd8e on February 22, 2025 at 17:23

github-actions bot commented Feb 22, 2025

Review updated until commit ebe727b

Description

  • Replace std::vector<at::Tensor> with KernelArgumentHolder

  • Update tests and benchmarks to use KernelArgumentHolder

  • Enhance allocateOutputs and run methods to return KernelArgumentHolder (see the sketch below)

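As a quick orientation, here is a minimal before/after sketch of what this change means at call sites (not taken from the diff; t0, t1, and ref are placeholder tensors, and the old signature is inferred from the PR title):

```cpp
KernelExecutor ke;
ke.compile(&fusion, {t0, t1});

// Before this PR (inferred): run() returned a flat std::vector<at::Tensor>,
// so outputs went straight into ATen calls:
//   std::vector<at::Tensor> outs = ke.run({t0, t1});
//   NVF_CHECK(at::allclose(outs[0], ref));

// After this PR: run() returns a KernelArgumentHolder of polymorphic values,
// so tensor outputs are extracted explicitly before ATen comparisons:
KernelArgumentHolder cg_outputs = ke.run({t0, t1});
NVF_CHECK(at::allclose(cg_outputs[0].as<at::Tensor>(), ref));
```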

Changes walkthrough 📝

Relevant files

Enhancement (54 files)

| File | Change summary | Diff |
| --- | --- | --- |
| test_matmul.cpp | Update allclose checks to use KernelArgumentHolder | +95/-70 |
| test_alias.cpp | Update allclose checks to use KernelArgumentHolder | +74/-52 |
| test_matmul_scheduler.cpp | Update allclose checks to use KernelArgumentHolder | +55/-40 |
| test_gpu3.cpp | Update allclose checks to use KernelArgumentHolder | +55/-50 |
| test_resize.cpp | Update allclose checks to use KernelArgumentHolder | +57/-56 |
| allocations.cpp | Change allocateOutputs return type to KernelArgumentHolder | +2/-2 |
| executor.cpp | Replace std::vector with KernelArgumentHolder | +44/-45 |
| executor.cpp | Replace std::vector with KernelArgumentHolder | +16/-16 |
| fusion_definition.cpp | Replace std::vector with KernelArgumentHolder | +19/-20 |
| executor_utils.cpp | Replace std::vector with KernelArgumentHolder | +9/-9 |
| utils.cpp | Replace std::vector with KernelArgumentHolder | +6/-14 |
| fusion_executor_cache.cpp | Replace std::vector with KernelArgumentHolder | +5/-7 |
| fusion_kernel_runtime.cpp | Replace std::vector with KernelArgumentHolder | +4/-5 |
| compiled_kernel.cpp | Replace std::vector with KernelArgumentHolder | +10/-9 |
| matmul.cpp | Update benchmarks to use KernelArgumentHolder | +6/-3 |
| test_sdpa_node.cpp | Update validateSdpaFwdOutputs to use KernelArgumentHolder | +4/-4 |
| validator.cpp | Update testValidate to use KernelArgumentHolder | +8/-4 |
| test_gpu_view.cpp | Update test cases to use KernelArgumentHolder | +3/-3 |
| test_preseg_passes.cpp | Update test cases to use KernelArgumentHolder | +4/-3 |
| test_no_op.cpp | Update test cases to use KernelArgumentHolder | +3/-2 |
| test_multidevice_pipeline.cpp | Update outputs to use KernelArgumentHolder | +5/-4 |
| test_loop_domain_scheduling.cpp | Update test cases to use KernelArgumentHolder | +2/-2 |
| gelu_backward.cpp | Update outputs to use KernelArgumentHolder | +2/-2 |
| test_replay.cpp | Update test cases to use KernelArgumentHolder | +2/-2 |
| test_allocation_order_inference.cpp | Update test cases to use KernelArgumentHolder | +3/-2 |
| test_swizzle.cpp | Update test cases to use KernelArgumentHolder | +3/-2 |
| lstm_cell.cpp | Update outputs to use KernelArgumentHolder | +2/-2 |
| transformer.cpp | Update outputs to use KernelArgumentHolder | +1/-1 |
| test_translate_mma.cpp | Update test cases to use KernelArgumentHolder | +1/-1 |
| executor_dispatch.cpp | Update run method to use KernelArgumentHolder | +2/-2 |
| test_segmentation.cpp | Update test cases to use KernelArgumentHolder | +2/-1 |
| test_alias_analysis.cpp | Update test cases to use KernelArgumentHolder | +1/-1 |
| executor.cpp | Update runWithInput method to use KernelArgumentHolder | +1/-1 |
| test_gpu_indexing_ops.cpp | Update test cases to use KernelArgumentHolder | +1/-1 |
| test_embedding_node.cpp | Update test cases to use KernelArgumentHolder | +1/-1 |
| main.cpp | Update sinh_nvfuser to return Tensor from KernelArgumentHolder | +1/-1 |
| test_host_ir_integration.cpp | Update test cases to use KernelArgumentHolder | +1/-1 |
| main.cpp | Update sinh_nvfuser to return Tensor from KernelArgumentHolder | +1/-1 |
| test_combined_inner_outer_reduction.cpp | Update test cases to use KernelArgumentHolder | +1/-1 |
| utils.cpp | Update scheduleAndRun to use KernelArgumentHolder | +1/-1 |
| test_multidevice_overlap.cpp | Update test cases to use KernelArgumentHolder | +1/-1 |
| run_nvfuser_tests.py | Update test timeout check | +1/-1 |
| executor.h | Update run method to use KernelArgumentHolder | +6/-7 |
| executor.h | Update run and runWithInput methods to use KernelArgumentHolder | +4/-5 |
| fusion_kernel_runtime.h | Update runWithInputs and runKernelWithInput methods to use KernelArgumentHolder | +2/-2 |
| validator.h | Update testValidate to use KernelArgumentHolder | +2/-2 |
| fusion_executor_cache.h | Update runFusionWithInputs method to use KernelArgumentHolder | +1/-1 |
| allocations.h | Update allocateOutputs to return KernelArgumentHolder | +1/-1 |
| executor_dispatch.h | Update run method to use KernelArgumentHolder | +2/-2 |
| executor.h | Update runWithInput method to use KernelArgumentHolder | +1/-1 |
| executor_utils.h | Update validateVectorizedTensors to use KernelArgumentHolder | +1/-1 |
| compiled_kernel.h | Update run method to use KernelArgumentHolder | +1/-1 |
| utils.h | Update CGResultsPackage to use KernelArgumentHolder | +1/-1 |
| tmem.md | Update test cases to use KernelArgumentHolder | +16/-16 |

Tests (19 files)

| File | Change summary | Diff |
| --- | --- | --- |
| test_allocation_domain.cpp | Update tests to use KernelArgumentHolder | +35/-19 |
| test_host_irs.cpp | Update tests to use KernelArgumentHolder | +14/-14 |
| test_move_split_cat.cpp | Update tests to use KernelArgumentHolder | +19/-19 |
| test_circular_buffering.cpp | Update tests to use KernelArgumentHolder | +20/-11 |
| test_multidevice_sharding.cpp | Update tests to use KernelArgumentHolder | +24/-12 |
| test_multidevice_lower_communication.cpp | Update tests to use KernelArgumentHolder | +24/-12 |
| test_gpu1.cpp | Update tests to use KernelArgumentHolder | +14/-12 |
| test_multidevice_tutorial.cpp | Update tests to use KernelArgumentHolder | +19/-11 |
| test_rng.cpp | Update tests to use KernelArgumentHolder | +16/-14 |
| test_gpu2.cpp | Update tests to use KernelArgumentHolder | +11/-11 |
| test_mma.cpp | Update tests to use KernelArgumentHolder | +9/-9 |
| test_matmul_aten_evaluation.cpp | Update tests to use KernelArgumentHolder | +10/-10 |
| test_tutorial.cpp | Update tests to use KernelArgumentHolder | +9/-9 |
| test_gpu_transpose.cpp | Update tests to use KernelArgumentHolder | +6/-6 |
| test_gpu_outer_reduction.cpp | Update tests to use KernelArgumentHolder | +5/-8 |
| test_multidevice_host_ir.cpp | Update tests to use KernelArgumentHolder | +7/-6 |
| test_multidevice_transformer.cpp | Update tests to use KernelArgumentHolder | +12/-6 |
| test_memory.cpp | Update tests to use KernelArgumentHolder | +4/-4 |
| test_matmul_sass.cpp | Update tests to use KernelArgumentHolder | +2/-2 |

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Consistent Use of `as()`

Ensure that the use of as<at::Tensor>() is consistent across the updated test cases and that each conversion is actually required and does not introduce unintended behavior.
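A hedged illustration of the two patterns the excerpt below exercises (assuming the accessors visible in this PR: operator[] yields a PolymorphicValue rather than an at::Tensor, and test helpers such as testValidate were updated to accept the holder directly):

```cpp
auto cg_outputs = ke.run({inputs.first, inputs.second});

// Direct ATen checks need an explicit extraction, since operator[] does not
// return an at::Tensor:
NVF_CHECK(at::allclose(cg_outputs[0].as<at::Tensor>(), tref, 0.0001, 0.0001));

// Helpers updated in this PR take the KernelArgumentHolder as-is, so no
// conversion is needed there:
testValidate(
    &fusion, cg_outputs, {inputs.first, inputs.second}, __LINE__, __FILE__);
```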

  auto cg_outputs = ke.run({inputs.first, inputs.second});
  auto tref = atMatmul(
      inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);
  NVF_CHECK(at::allclose(cg_outputs[0].as<at::Tensor>(), tref, 0.0001, 0.0001));
}

// Single batch dimension which is broadcast
TEST_P(MatmulTestWithLayout, AmpereMatmulBroadcastBatch) {
  NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0);

  // Keep multiples of 8 to keep vectorizable.
  int M = 504, N = 136, K = 248;

  Fusion fusion;
  FusionGuard fg(&fusion);

  auto shapes = matmulAtInputShape3DTuring(-1, -1, -1, layout);

  auto tv0 = makeContigConcreteTensor(shapes.first, DataType::Half);
  auto tv1 = makeContigConcreteTensor(shapes.second, DataType::Half);

  fusion.addInput(tv0);
  fusion.addInput(tv1);

  tv0 = canonicalizeInputToBMNK(tv0, layout, MmaOperand::A);
  tv1 = canonicalizeInputToBMNK(tv1, layout, MmaOperand::B);
  // Broadcast inputs to 1, M, 1, K and 1, 1, N, K
  tv0 = broadcast(tv0, {true, false, false, false});
  tv1 = broadcast(tv1, {true, false, false, false});
  auto tv2 = fusedMultiplySum(tv0, tv1, {-1});

  fusion.addOutput(tv2);

  MatMulTileOptions gemm_tile;
  gemm_tile.cta_tile = GemmTile(128, 128, 32);
  gemm_tile.warp_tile = GemmTile(64, 64, 32);

  MatmulParams mparams;
  mparams.supported_vec_size = {8, 8, 4};
  mparams.mma_macro = MmaMacro::Ampere_16_8_16;
  mparams.tile_sizes = gemm_tile;
  mparams.async_gmem_load_operands = true;
  mparams.circular_buffer_options.circular_buffer_smem_write = true;
  mparams.circular_buffer_options.circular_buffer_smem_read = true;
  mparams.circular_buffer_options.smem_circular_buffer_stage = 4;
  SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
      ->schedule(&fusion, &mparams);

  auto inputs = matmulAtInput3DTuring(M, N, K, layout);

  KernelExecutor ke;
  NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
      8,
      0,
      ke.compile(
          &fusion,
          {inputs.first, inputs.second},
          LaunchParams(),
          matmul_cparams));
  ASSERT_TRUE(getBankConflictInfo(ke.compiledKernel()->kernel()).empty());
  ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
      ke.compiledKernel()->kernel()));
  auto cg_outputs = ke.run({inputs.first, inputs.second});
  auto tref =
      atMatmul(
          inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout)
          .unsqueeze(0);
  NVF_CHECK(at::allclose(cg_outputs[0].as<at::Tensor>(), tref, 0.0001, 0.0001));
}

TEST_P(MatmulTestWithLayout, AmperePrologueFusionBroadcast) {
  NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0);

  // Keep multiples of 8 to keep vectorizable.
  int M = 504, N = 136, K = 248;

  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeContigTensor(2, DataType::Half);
  auto tv1 = makeContigTensor(2, DataType::Half);
  fusion.addInput(tv0);
  fusion.addInput(tv1);

  tv0 = canonicalizeInputToBMNK(tv0, layout, MmaOperand::A);
  tv1 = canonicalizeInputToBMNK(tv1, layout, MmaOperand::B);
  auto tv2 = fusedMultiplySum(tv0, tv1, {-1});

  fusion.addOutput(tv2);

  MatMulTileOptions gemm_tile;
  gemm_tile.cta_tile = GemmTile(128, 128, 32);
  gemm_tile.warp_tile = GemmTile(64, 64, 32);

  MatmulParams mparams;
  mparams.supported_vec_size = {8, 8, 4};
  mparams.mma_macro = MmaMacro::Ampere_16_8_16;
  mparams.tile_sizes = gemm_tile;
  mparams.async_gmem_load_operands = true;
  mparams.circular_buffer_options.circular_buffer_smem_write = true;
  mparams.circular_buffer_options.circular_buffer_smem_read = true;
  mparams.circular_buffer_options.smem_circular_buffer_stage = 4;
  SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
      ->schedule(&fusion, &mparams);

  auto inputs = matmulAtInput2D(M, N, K, layout);

  KernelExecutor ke;
  NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
      8,
      0,
      ke.compile(
          &fusion,
          {inputs.first, inputs.second},
          LaunchParams(),
          matmul_cparams));
  ASSERT_TRUE(getBankConflictInfo(ke.compiledKernel()->kernel()).empty());
  ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
      ke.compiledKernel()->kernel()));
  auto cg_outputs = ke.run({inputs.first, inputs.second});
  auto tref = atMatmul(
      inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);
  NVF_CHECK(at::allclose(cg_outputs[0].as<at::Tensor>(), tref, 0.0001, 0.0001));
}

TEST_P(MatmulTestWithLayout, AmpereProloguePointwise) {
  NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0);

  // Keep multiples of 8 to keep vectorizable.
  int M = 504, N = 136, K = 248;

  Fusion fusion;
  FusionGuard fg(&fusion);

  auto shapes = matmulAtInputShape3DTuring(-1, -1, -1, layout);

  auto tv0 = makeContigConcreteTensor(shapes.first, DataType::Half);
  auto tv1 = makeContigConcreteTensor(shapes.second, DataType::Half);

  fusion.addInput(tv0);
  fusion.addInput(tv1);

  tv0 = canonicalizeInputToBMNK(tv0, layout, MmaOperand::A);
  tv0 = castOp(DataType::Half, sin(tv0));
  tv1 = canonicalizeInputToBMNK(tv1, layout, MmaOperand::B);
  tv1 = castOp(DataType::Half, sin(tv1));
  auto tv2 = fusedMultiplySum(tv0, tv1, {-1});

  fusion.addOutput(tv2);

  MatMulTileOptions gemm_tile;
  gemm_tile.cta_tile = GemmTile(128, 128, 32);
  gemm_tile.warp_tile = GemmTile(64, 64, 32);

  MatmulParams mparams;
  mparams.supported_vec_size = {8, 8, 4};
  mparams.mma_macro = MmaMacro::Ampere_16_8_16;
  mparams.tile_sizes = gemm_tile;
  mparams.async_gmem_load_operands = true;
  mparams.circular_buffer_options.circular_buffer_smem_write = true;
  mparams.circular_buffer_options.circular_buffer_smem_read = true;
  mparams.circular_buffer_options.smem_circular_buffer_stage = 4;
  SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
      ->schedule(&fusion, &mparams);

  auto inputs = matmulAtInput3DTuring(M, N, K, layout);

  KernelExecutor ke;
  NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
      8,
      0,
      ke.compile(
          &fusion,
          {inputs.first, inputs.second},
          LaunchParams(),
          matmul_cparams));
  ASSERT_TRUE(getBankConflictInfo(ke.compiledKernel()->kernel()).empty());
  ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
      ke.compiledKernel()->kernel()));
  auto cg_outputs = ke.run({inputs.first, inputs.second});
  auto tref = atMatmul(
      inputs.first.sin().to(at::kFloat),
      inputs.second.sin().to(at::kFloat),
      layout);
  NVF_CHECK(at::allclose(cg_outputs[0].as<at::Tensor>(), tref, 0.0001, 0.0001));
}

TEST_P(MatmulTestWithLayout, AmpereMatmulBFloat16) {
  NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0);

  // Keep multiples of 8 to keep vectorizable.
  int M = 504, N = 136, K = 248;

  Fusion fusion;
  FusionGuard fg(&fusion);

  auto shapes = matmulAtInputShape3DTuring(-1, -1, -1, layout);

  auto tv0 = makeContigConcreteTensor(shapes.first, DataType::BFloat16);
  auto tv1 = makeContigConcreteTensor(shapes.second, DataType::BFloat16);

  fusion.addInput(tv0);
  fusion.addInput(tv1);

  tv0 = canonicalizeInputToBMNK(tv0, layout, MmaOperand::A);
  tv1 = canonicalizeInputToBMNK(tv1, layout, MmaOperand::B);
  auto tv2 = fusedMultiplySum(tv0, tv1, {-1});

  fusion.addOutput(tv2);

  MatMulTileOptions gemm_tile;
  gemm_tile.cta_tile = GemmTile(128, 128, 32);
  gemm_tile.warp_tile = GemmTile(64, 64, 32);

  MatmulParams mparams;
  mparams.supported_vec_size = {8, 8, 4};
  mparams.mma_macro = MmaMacro::Ampere_16_8_16;
  mparams.tile_sizes = gemm_tile;
  mparams.async_gmem_load_operands = true;
  mparams.circular_buffer_options.circular_buffer_smem_write = true;
  mparams.circular_buffer_options.circular_buffer_smem_read = true;
  mparams.circular_buffer_options.smem_circular_buffer_stage = 4;
  SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
      ->schedule(&fusion, &mparams);

  auto inputs = matmulAtInput3DTuring(M, N, K, layout, at::kBFloat16);

  KernelExecutor ke;
  NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
      8,
      0,
      ke.compile(
          &fusion,
          {inputs.first, inputs.second},
          LaunchParams(),
          matmul_cparams));
  ASSERT_TRUE(getBankConflictInfo(ke.compiledKernel()->kernel()).empty());
  ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
      ke.compiledKernel()->kernel()));
  auto cg_outputs = ke.run({inputs.first, inputs.second});
  auto tref = atMatmul(
      inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);
  NVF_CHECK(at::allclose(cg_outputs[0].as<at::Tensor>(), tref, 0.0001, 0.0001));
}

// Matmul test for Ampere MMA: with pipelined gmem load
TEST_P(MatmulTestWithLayout, AmpereMatmulPipelineGmem) {
  NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0);

  // Keep multiples of 8 to keep vectorizable.
  int M = 504, N = 136, K = 248;
  REQUIRE_DEVICE_SMEM_SIZE(70 << 10, 0);

  // Gmem pipeline stage
  for (auto stage : {3, 4}) {
    Fusion fusion;
    FusionGuard fg(&fusion);

    auto shapes = matmulAtInputShape3DTuring(-1, -1, -1, layout);

    auto tv0 = makeContigConcreteTensor(shapes.first, DataType::Half);
    auto tv1 = makeContigConcreteTensor(shapes.second, DataType::Half);

    fusion.addInput(tv0);
    fusion.addInput(tv1);

    tv0 = canonicalizeInputToBMNK(tv0, layout, MmaOperand::A);
    tv1 = canonicalizeInputToBMNK(tv1, layout, MmaOperand::B);
    auto tv2 = fusedMultiplySum(tv0, tv1, {-1});

    fusion.addOutput(tv2);

    MatMulTileOptions gemm_tile;
    gemm_tile.cta_tile = GemmTile(128, 128, 32);
    gemm_tile.warp_tile = GemmTile(64, 64, 32);

    MatmulParams mparams;
    mparams.supported_vec_size = {8, 8, 4};
    mparams.mma_macro = MmaMacro::Ampere_16_8_16;
    mparams.tile_sizes = gemm_tile;
    mparams.async_gmem_load_operands = true;
    mparams.circular_buffer_options.circular_buffer_smem_write = true;
    mparams.circular_buffer_options.smem_circular_buffer_stage = stage;
    SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
        ->schedule(&fusion, &mparams);

    auto inputs = matmulAtInput3DTuring(M, N, K, layout);

    KernelExecutor ke;
    NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
        8,
        0,
        ke.compile(
            &fusion,
            {inputs.first, inputs.second},
            LaunchParams(),
            matmul_cparams));
    ASSERT_TRUE(getBankConflictInfo(ke.compiledKernel()->kernel()).empty());
    ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
        ke.compiledKernel()->kernel()));
    auto cg_outputs = ke.run({inputs.first, inputs.second});

Function Signature Changes

Review the changes in function signatures, particularly the introduction of KernelArgumentHolder in place of std::vector<at::Tensor>. Ensure that these changes do not break existing functionality and that all necessary adjustments have been made.
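For orientation, the shape of the signature change under review, side by side (the "before" form is inferred from the PR title and is not part of the excerpt that follows):

```cpp
// Before (inferred): outputs were returned as a plain tensor vector.
// std::vector<at::Tensor> ExprEvalExecutor::run(
//     KernelArgumentHolder& args,
//     std::vector<at::Tensor> outputs);

// After: inputs and outputs both travel in a KernelArgumentHolder, which can
// hold tensors and other polymorphic values uniformly.
KernelArgumentHolder ExprEvalExecutor::run(
    KernelArgumentHolder& args,
    KernelArgumentHolder outputs);
```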

  return fusion_ != nullptr;
}

KernelArgumentHolder ExprEvalExecutor::run(
    KernelArgumentHolder& args,
    KernelArgumentHolder outputs) {
  FUSER_PERF_SCOPE("ExprEvalExecutor::run");

  if (isProfilerEnabled()) {
    NVF_CHECK(
        group_id_ >= 0,
        "An invalid segment id is passed to FusionProfiler!:",
        group_id_);
    SegmentProfiler& sprof = FusionProfiler::segment(group_id_);
    sprof.inputBytesAccessed(computeBytes(args));
    sprof.scheduler(toString(SchedulerType::ExprEval));
    sprof.startKernel();
  }

  NVF_ERROR(fusion_, "Need to compile before you can run.");
  // Bind fusion inputs
  auto expr_eval = executor_utils::bindInputs(args, fusion_.get());
  {
    NVF_ERROR(
        outputs.empty(),
        "Fusion executor is using expression evaluator,",
        " and expects that the outputs are not populated, which they were.");
    if (outputs.empty()) {
      for (const auto& out_val : fusion_->outputs()) {
        auto out_tensor =
            expr_eval.evaluate(out_val->as<TensorView>()).as<at::Tensor>();
        expr_eval.bind(out_val, out_tensor);
        outputs.push(out_tensor);
      }
    }
  }
  if (isProfilerEnabled()) {
    FusionProfiler::segment(group_id_).stopKernel();
    FusionProfiler::segment(group_id_).setDevice(args.getDeviceIndex());
  }
  return outputs;
}

namespace {
bool hasCpuScalarOutputs(Fusion* _fusion) {
  if (_fusion->exprs().empty()) {
    return false;
  }

  std::unordered_map<TensorView*, bool> tv_is_cpu_map;
  for (Expr* expr : StmtSort::getExprs(_fusion)) {
    bool has_cpu_scalar_input = false;
    bool has_cuda_input = false;
    for (Val* inp : expr->inputs()) {
      if (auto* inp_tv = dynamic_cast<TensorView*>(inp)) {
        if (inp_tv->isCpuScalar()) {
          has_cpu_scalar_input = true;
        } else {
          has_cuda_input = true;
          // Return early -- found at least one CUDA input
          break;
        }
      }
    }
    if (!has_cuda_input && has_cpu_scalar_input) {
      // Expr is of the second category, and has all CPU scalar outputs
      for (Val* out : expr->outputs()) {
        if (auto* out_tv = dynamic_cast<TensorView*>(out)) {
          tv_is_cpu_map[out_tv] = true;
        }
      }
    }
  }

  bool has_any_cpu_output = std::any_of(
      _fusion->outputs().begin(),
      _fusion->outputs().end(),
      [&tv_is_cpu_map](Val* out) {
        return out->isA<TensorView>() && tv_is_cpu_map[out->as<TensorView>()];
      });
  return has_any_cpu_output;
}
} // namespace

bool KernelExecutor::supported(Fusion* fusion) {
  FUSER_PERF_SCOPE("KernelExecutor::supported");
  return !hasCpuScalarOutputs(fusion);
}

void KernelExecutor::compile(
    Fusion* fusion,
    const KernelArgumentHolder& args,
    const LaunchParams& launch_constraints,
    CompileParams compile_params,
    SchedulerType scheduler_type) {
  FUSER_PERF_SCOPE("KernelExecutor::compile");

  NVF_ERROR(
      supported(fusion),
      "KernelExecutor does not support the Fusion provided.");

  NVF_ERROR(
      !fusion->outputs().empty(), "No output found for this kernel, aborting.");

  auto device = c10::Device(c10::DeviceType::CUDA, args.getDeviceIndex());

  if (isProfilerEnabled()) {
    NVF_CHECK(
        group_id_ >= 0,
        "An invalid segment id is passed to FusionProfiler!:",
        group_id_);
    FusionProfiler::segment(group_id_).setDevice(device.index());
    FusionProfiler::segment(group_id_).startCompile();
  }

  //! Force index_type to int and disable magic zero if we detect that the
  //! kernel contains any TMA memory operations.
  std::vector<Expr*> exprs = fusion->exprs();
  bool has_cp_async_bulk = std::any_of(exprs.begin(), exprs.end(), [](Expr* e) {
    return ir_utils::isCpAsyncBulk(e);
  });

  // Disable magic zero if there are any TMA operations in Fusion
  if (has_cp_async_bulk) {
    compile_params.enable_magic_zero = false;
  }

  // Set the index type of compile params if not already set. If set,
  // make sure the compile param type is valid with the given kernel
  // arguments.
  auto arg_index_type = args.getSmallestIndexTypeOfArguments();
  if (compile_params.index_type.has_value()) {
    // If the int32 compilation is requested, but the arguments demand
    // int64, that's an error
    NVF_ERROR(
        !(compile_params.index_type.value() == PrimDataType::Int32 &&
          arg_index_type == PrimDataType::Int),
        "Compilation with int32 is requested but int64 is required for the arguments");
  } else {
    // If the given compile option doesn't specify the index type, and
    // the arguments require 64-bit indexing, we need to use 64-bit
    // indexing. Note that if the arg type is 32-bit, it doesn't mean
    // it's safe to use 32-bit for the whole kernel, so unless it's
    // specified through CompileParams, we do not use 32-bit indexing.
    compile_params.index_type = arg_index_type;
  }

  c10::DeviceGuard dg(device);

  NVF_ERROR(device.is_cuda(), "Provided device to CUDA fuser is the CPU.");
  auto properties = at::cuda::getDeviceProperties(device.index());
  // TODO: These properties should be set as part of the constructor so that it
  // can be const
  device_smem_limit_ = static_cast<int64_t>(properties->sharedMemPerBlockOptin);
  warp_size_ = properties->warpSize;

  // Lowered is needed to compute launch parameters as it uses the CA map. We
  // could modify that, but simply generating that part first.
  compiled_kernel_ = std::make_unique<CompiledKernel>(
      fusion,
      compile_params,
      device,
      scheduler_type,
      fusion_id_,
      concrete_id_,
      runtime_id_,
      group_id_,
      lowering_hooks_,
      post_lowering_hooks_);

  // TODO: pass block_size here;
  std::optional<int64_t> dynamic_smem = std::nullopt;
  std::optional<int64_t> block_size = std::nullopt;

  auto launch_params = launch_constraints;
  if (!args.empty()) {
    auto expr_eval =
        executor_utils::bindInputs(args, compiled_kernel_->lowered()->kernel());
    NVF_ERROR(compile_params.index_type.has_value());
    launch_params = computeLaunchParams(
        launch_constraints,
        expr_eval,
        warp_size_,
        compile_params.index_type.value());
    block_size = launch_params.nThreads();
    dynamic_smem = launch_params.smem();
    NVF_ERROR(block_size > 0, "launch param inferred block size < 0");
  }

  // Now that we have launch parameters we can compile the kernel. It's a bit
  // odd we need launch parameters for compilation, need to go back and check
  // why this is the case.
  compiled_kernel_->compile(launch_params.nThreads());

  // These should be nullopt at this point, but reset just in case
  resetCompiledKernelProperties();

  // If the dynamic shmem size is known, make sure the compiled kernel
  // has at least that size of dynamic shmem
  if (dynamic_smem.has_value()) {
    ensureAvailableDynamicSmemSize(dynamic_smem.value());
  }
  if (isProfilerEnabled()) {
    FusionProfiler::segment(group_id_).stopCompile();
  }
}

LaunchParams KernelExecutor::computeLaunchParams(
    const LaunchParams& launch_constraints,
    ExpressionEvaluator& expr_eval,
    const int64_t warp_size,
    DataType index_type) {
  FUSER_PERF_SCOPE("KernelExecutor::computeLaunchParams");
  NVF_ERROR(warp_size > 0, "WARP_SIZE should be larger than 0");

  LaunchParams launch_params;

  auto data_cache = compileTimeDataCache();

  auto lower = compiled_kernel_->lowered().get();
  if (compiled_kernel_->getUsedTVs().empty()) {
    compiled_kernel_->setUsedTVs();
  }
  auto& used_tvs = compiled_kernel_->getUsedTVs();

  auto parallel_binding_ids_entry =
      executor_utils::caching::ExecutorCompileTimeEntry<
          executor_utils::caching::ParallelBindingIterDomains>(
          data_cache, [&used_tvs, &lower]() {
            return std::make_unique<std::vector<IterDomain*>>(
                executor_utils::getParallelBindingsIterDomains(
                    lower, used_tvs));
          });
  auto& parallel_binding_ids = parallel_binding_ids_entry.get();

  auto parallel_iter_extent_entry =
      executor_utils::caching::ExecutorCompileTimeEntry<
          executor_utils::caching::ParallelIterExtentMap>(
          data_cache, [&parallel_binding_ids]() {
            return executor_utils::getParallelIterExtents(parallel_binding_ids);
          });
  auto& parallel_iter_extents = parallel_iter_extent_entry.get();

  const auto& simplified_parallel_iter_extents =
      lower->parallelDimensionMap().getMap();

  // TODO: Need to redesign this part a bit to
  //   find the right place to trigger evaluate
  if (expr_eval.precomputedValues()) {
    expr_eval.precomputedValues()->bindParallelExtents(
        parallel_iter_extents, launch_constraints);
    expr_eval.precomputedValues()->evaluate();
  }

  // If any dimension was set in launch constraints we need to run through
  // IterDomains that have been parallelized, and bind those values. Or make
  // sure if they could be inferred the inference matches what was set.
  for (auto& entry : parallel_iter_extents) {
    auto p_type = entry.first;
    if (launch_constraints.hasDim(p_type)) {
      auto parallel_extents = entry.second;
      for (auto extent : parallel_extents) {
        auto inferred_val = expr_eval.evaluate(extent);
        if (inferred_val.hasValue()) {
          // This value could have been inferred, make sure it was set right.
          bool valid =
              inferred_val.as<int64_t>() == launch_constraints.getDim(p_type) ||
              launch_constraints.getRawVal(p_type) == -1;
          if (!useFallback() && !valid) {
            TORCH_WARN_ONCE(
                "Cannot validate parallelization scheme, "
                "this may be due to mixed broadcast axes that are parallelized.");
          }
        } else if (!expr_eval.precomputedValues()) {
          expr_eval.bind(extent, launch_constraints.getDim(p_type));
        }
        if (!launch_params.hasDim(p_type)) {
          // Bind the launch constraint into our evaluation context
          launch_params.bind(launch_constraints.getDim(p_type), p_type);
          // Makes sure the p-types bound to evaluators are the
          //  final values that will become the actual launch
          //  param size to ensure accurate smem buffer size
          //  computation.
          expr_eval.bind(p_type, launch_constraints.getDim(p_type));
        }
      }
    }
  }

  // Run through the rest of the parallel IterDomains and infer their size
  for (auto [p_type, extent] : simplified_parallel_iter_extents) {
    FUSER_PERF_SCOPE("KernelExecutor::ParallelBindingResolution");
    auto val = expr_eval.evaluate(extent);
    NVF_ERROR(
        val.hasValue(),
        "Tried to evaluate the extent, ",
        extent->toInlineString(),
        " for the ptype: ",
        p_type,
        " to set launch bounds but could not.");

    if (val > 0) {
      expr_eval.bind(p_type, val);
      launch_params.bind(val.as<int64_t>(), p_type);
    }
  }

  // Re-run the integer machine with all
  //  the thread sizes now determined.
  if (expr_eval.precomputedValues()) {
    expr_eval.precomputedValues()->evaluate();
  }

  const auto kernel = compiled_kernel_->lowered()->kernel();
  const auto& kernel_summary = kernel->summary();

  // Calculate Dynamic Shared Memory Size
  // Add workspace for reduction and broadcast
  int64_t reduction_broadcast_workspace = 0;
  const bool has_workspace = kernel_summary.has_block_reductions ||
      kernel_summary.has_grid_reductions ||
      kernel_summary.has_block_broadcasts || kernel_summary.has_grid_broadcasts;
  if (has_workspace &&
      kernel_summary.largest_smem_data_type != DataType::Null) {
    // Not using nThreads here since it does not handle uninitialized value

    // TODO: here is an optimization opportunity since welford uses int64_t for
    // N while the data type is not necessarily double. But it may need more
    // work on the alignment
    const int welford_factor =
        kernel_summary.has_block_welford || kernel_summary.has_grid_welford ? 3
                                                                            : 1;
    // in outer reduction, may group iteration domain, e.g. when vectorized.
    const int64_t grouped_iter_factor = kernel_summary.num_grouped_iterations;

    NVF_CHECK(
        !(kernel_summary.has_iter_grouped_reductions && welford_factor == 3),
        "can't have welford and iter grouped reductions at the same time! Should be handled by grouped welford!");

    reduction_broadcast_workspace =
        (int64_t)dataTypeSize(
            kernel_summary.largest_smem_data_type, index_type) *
        grouped_iter_factor * welford_factor * launch_params.bdimx() *
        launch_params.bdimy() * launch_params.bdimz();

    if (kernel_summary.has_outer_grouped_grid_welford) {
      reduction_broadcast_workspace = std::max(
          reduction_broadcast_workspace,
          (int64_t)kernel_summary.outer_grouped_grid_welford_largest_smem_size);
    }
  }

  const auto dynamic_smem_size = computeSharedMemory(
      expr_eval,
      kernel_summary.dynamic_smem_allocations,
      index_type,
      reduction_broadcast_workspace);

  // Check that requested smem size can be dynamically allocated.
  //  This check is only done once a kernel has been compiled, since
  //  maybe_available_dynamic_smem_ needs to be evaluated on
  //  a compiled kernel.
  if (compiled_kernel_->isCompiled()) {
    validateDynamicSmemSize(dynamic_smem_size);
  }

  launch_params.setSmem(dynamic_smem_size);

  return launch_params;
}

std::vector<GlobalBufferInfo> KernelExecutor::getIntermediateBufferInfo(
    ExpressionEvaluator& expr_eval,
    DataType index_type) {
  FUSER_PERF_SCOPE("KernelExecutor::getIntermediateBufferInfo");
  std::vector<GlobalBufferInfo> global_buffers;

  const auto kernel = compiled_kernel_->lowered()->kernel();
  const auto& kernel_summary = kernel->summary();

  for (auto alloc : kernel_summary.global_allocations) {
    NVF_ERROR(
        alloc->buffer()->isA<TensorView>(),
        "Cannot allocate global buffers that are not tensors.");
    auto tv = alloc->buffer()->as<TensorView>();
    if (tv->isFusionOutput()) {
      continue;
    }
    GlobalBufferInfo info;
    info.tv = tv;
    info.zero_init = alloc->zeroInit();
    info.resets_to_zero = alloc->resetsToZero();
    // TODO: Allocation size needs to consider both expanded domains
    // as well as halo. Currently, allocation of tensors with halo is
    // only supported by inferShapeOfIntermediate, whereas expanded
    // domains are only supported by inferShapeOfOutput. Until the
    // halo support is revisited, use the former for all tensors
    // unless expanded and the latter otherwise. This assumes there's
    // no expanded domains with halo, which is fine for now.
    const auto has_expanded_domains = std::any_of(
        tv->getMaybeAllocationDomain().begin(),
        tv->getMaybeAllocationDomain().end(),
        [](IterDomain* id) { return id->hasExpandedExtent(); });
    std::tie(info.sizes, info.strides) = has_expanded_domains
        ? inferShapeOfOutput(tv, expr_eval)
        : inferShapeOfIntermediate(tv, alloc, expr_eval);
    auto dtype = (tv->dtype() == DataType::Index ? index_type : tv->dtype());
    info.type = data_type_to_aten(dtype);

    // Remember the tensor buffer used for storing kernel profile
    if (isOptionEnabled(EnableOption::KernelProfile) &&
        tv == kernel->profile().getBuffer()) {
      info.is_profile_buffer = true;
    }

    global_buffers.emplace_back(info);
  }

  return global_buffers;
}

namespace {

// Make sure the index type of Kernel is valid
void validateIndexType(
    kir::Kernel* kernel,
    const CompileParams& compile_params) {
  NVF_ERROR(
      !compile_params.index_type.has_value() ||
          kernel->indexType() == compile_params.index_type.value(),
      "Kernel index type and compilation index type don't match. Kernel type: ",
      kernel->indexType(),
      ". Compilation index type: ",
      compile_params.index_type.value());
}

void validateCooperativeLaunch(
    CUfunction kernel,
    const LaunchParams& launch_params,
    int64_t device_index) {
  int num_blocks_per_SM = -1;
  auto block_size =
      launch_params.bdimx() * launch_params.bdimy() * launch_params.bdimz();
  NVFUSER_CUDA_SAFE_CALL(cuOccupancyMaxActiveBlocksPerMultiprocessor(
      &num_blocks_per_SM,
      kernel,
      (int)block_size,
      (size_t)launch_params.smem()));

  auto grid_size =
      launch_params.gdimx() * launch_params.gdimy() * launch_params.gdimz();
  auto max_active_blocks = num_blocks_per_SM *
      at::cuda::getDeviceProperties((c10::DeviceIndex)device_index)
          ->multiProcessorCount;
  NVF_ERROR(
      (int64_t)(max_active_blocks) >= grid_size,
      "Wanted to launch a cooperative kernel, however the number of blocks is greater than ",
      "what can be resident on the GPU at once. Need: ",
      grid_size,
      " (",
      launch_params.gdimx(),
      " * ",
      launch_params.gdimy(),
      " * ",
      launch_params.gdimz(),
      ") but limited to ",
      num_blocks_per_SM,
      " * ",
      at::cuda::getDeviceProperties(device_index)->multiProcessorCount);
}

// Dump fusion inputs and outputs as well as some useful fusion
// information. Note that inputs and outputs are those that are passed
// to KernelExecutor::runFusion, so outputs may not be given.
void dumpFusionArgs(
    int64_t fusion_id,
    const KernelArgumentHolder& args,
    const LaunchParams& launch_constraints,
    const CompileParams& compile_params,
    const KernelArgumentHolder& outputs) {
  debug() << "Arguments for fusion" << fusion_id << ":" << std::endl
          << "Inputs:" << std::endl;
  for (auto i : c10::irange(args.size())) {
    debug() << "  " << args[i] << std::endl;
  }
  debug() << "Outputs:" << std::endl;
  for (const auto& output : outputs) {
    debug() << PolymorphicValue_functions::toString(output) << std::endl;
  }
  debug() << launch_constraints.toString();
  debug() << "maxrregcount= " << compile_params.maxrregcount << std::endl;
}

// Dump arguments that are passed to a CUDA kernel call, which include
// the inputs and outputs of the fusion as well as temporary
// global-memory buffers. Unlike dumpFusionArgs, which dumps inputs
// and outputs passed to KernelExecutor::runFusion, this function
// dumps those that are passed to a CUDA kernel.
void dumpKernelArgs(
    const int64_t fusion_id,
    const int64_t group_id,
    const KernelArgumentHolder& args,
    size_t num_inputs,
    const KernelArgumentHolder& allocated_outputs,
    const KernelArgumentHolder& intermediates,
    const std::vector<GlobalBufferInfo>& intermediates_info) {
  using namespace PolymorphicValue_functions;
  debug() << "Arguments for fusion " << fusion_id << " group " << group_id
          << ":" << std::endl
          << "Inputs:" << std::endl;
  for (auto i : c10::irange(num_inputs)) {
    debug() << "  " << toString(args[i]) << std::endl;
  }
  debug() << "Outputs:" << std::endl;
  // note: add aliased outputs here.
  for (const auto& output : allocated_outputs) {
    debug() << "  " << PolymorphicValue_functions::toString(output)
            << std::endl;
  }
  debug() << "Intermediate global buffers:" << std::endl;
  for (const auto i : c10::irange(intermediates.size())) {
    const auto& zero_init = intermediates_info.at(i).zero_init;
    const auto& resets_to_zero = intermediates_info.at(i).resets_to_zero;
    debug() << "  " << PolymorphicValue_functions::toString(intermediates[i])
            << " is_zero_initialized: " << zero_init
            << " resets_to_zero: " << resets_to_zero << std::endl;
  }
}

} // namespace

void KernelExecutor::initializeExecutorEntry(
    ExecutorEntry& executor_entry,
    const KernelArgumentHolder& args,
    const LaunchParams& launch_constraints,
    const CompileParams& compile_params,
    const KernelArgumentHolder& output_args,
    DataType index_type) {
  FUSER_PERF_SCOPE("KernelExecutor::initializeExecutorEntry");

  ExpressionEvaluator expr_eval;
  evaluatorPrecomputedValues()->bindInputs(args);
  expr_eval.precomputedValues() = evaluatorPrecomputedValues().get();

  auto launch_params = computeLaunchParams(
      launch_constraints, expr_eval, warp_size_, index_type);

  for (const auto& entry : compiled_kernel_->kernel()->summary().validations) {
    NVF_CHECK(expr_eval.evaluate(entry.first).as<bool>(), entry.second);
  }

  executor_utils::validateVectorizedTensors(
      compiled_kernel_->kernel(),
      args,
      output_args,
      compileTimeDataCache(),
      expr_eval);

  executor_utils::validateCircularBuffering(
      compiled_kernel_->kernel(), expr_eval);

  executor_utils::validateIndexCasts(
      compiled_kernel_->kernel(), expr_eval, launch_params);

  // Check that a full warp exists in blockDim.x if the kernel contains
  // ElectSync predicate.
  constexpr int64_t warp_size = 32;
  NVF_ERROR(
      !compiled_kernel_->kernel()->summary().has_elect_sync_predicate ||
          launch_params.bdimx() >= warp_size,
      "This cuda kernel contains electSync predicate. "
      "Expected blockDim.x >= 32 but found ",
      launch_params.bdimx());

  std::vector<GlobalBufferInfo> output_info;

  if (output_args.empty()) {
    output_info = getBufferInfos(
        expr_eval, index_type, compiled_kernel_->kernel()->outputs());
  } else {
    // Need to save the information necessary for allocations as
    // future uses of this ExecutorEntry may not be provided with
    // allocated outputs
    for (const auto& output : output_args) {
      const auto& out_tensor = output.as<at::Tensor>();
      output_info.emplace_back(GlobalBufferInfo{
          .sizes = out_tensor.sizes().vec(),
          .strides = out_tensor.strides().vec(),
          .type = out_tensor.scalar_type()});
    }
  }

  auto intermediates = getIntermediateBufferInfo(expr_eval, index_type);

  // All information is gathered. Save it to ExecutorEntry
  executor_entry.launch_params = launch_params;
  executor_entry.outputs = output_info;
  executor_entry.intermediates = intermediates;
  executor_entry.init = true;
}

/// Copies the data, logical_size, and alloc_stride parameters to the
/// appropriate parts of entry.args[idx].
///
/// For GPU tensors, we pass a Tensor<type, rank, rank> struct (see
/// runtime/tensor.cu), where the rank describes the number of elements in the
/// shape and stride arrays. The actual shapes and strides are dynamic, but the
/// type and rank of the tensors are actually static (changing them would need
/// a new FusionDefinition). So we create the storage area for the
/// Tensor<t,r,r> during ::computeArgs, and then in this function we just
/// update that memory with the current values for the tensor's base address,
/// shape, and strides.
///
/// @param entry the entry we have previously setup for this fusion
/// @param idx the index into entry.args and related parallel arrays in the
///            entry.
/// @param idx_type_size generally sizeof(int32_t) or sizeof(int64_t); used for
///                      computing how large the arrays to copy are.
static void fillTensorArgMetadata(
    KernelExecutor::ExecutorEntry& entry,
    const PolymorphicValue& tensor_metadata,
    size_t idx,
    size_t idx_type_size) {
  void* data = tensor_metadata->*&TensorMetaData::data;
  // g++ has trouble inferring the types of more complicated fields through our
  // *& operators. Creating an `auto` alias as a temporary resolves this
  // problem.
#define TMD_ARRAY_REF(pv, field)                  \
  ({                                              \
    const auto& fld_tmp_ = pv->*&field;           \
    const c10::IntArrayRef& fld_aref_ = fld_tmp_; \
    fld_aref_;                                    \
  })
  const c10::IntArrayRef& shape =
      TMD_ARRAY_REF(tensor_metadata, TensorMetaData::logical_size);
  const c10::IntArrayRef& strides =
      TMD_ARRAY_REF(tensor_metadata, TensorMetaData::alloc_stride);
#undef TMD_ARRAY_REF

  // These are the three offsets we need to copy into.
  std::array<std::byte*, 3> offsets = {
      entry.args[idx].data(), // data ptr
      entry.args[idx].data() + sizeof(void*), // shape array
      // strides array:
      entry.args[idx].data() + sizeof(void*) + shape.size() * idx_type_size,
  };

  memcpy(offsets[0], &data, sizeof(void*));
  switch (idx_type_size) {
    case sizeof(int64_t): {
      // we use i64's for our sizes, so can use a simple copy here
      memcpy(offsets[1], shape.data(), shape.size() * sizeof(int64_t));
      memcpy(offsets[2], strides.data(), strides.size() * sizeof(int64_t));
    } break;
    case sizeof(int32_t): {
      // we need to cast per-element, so need a loop.
      // This case happens when the kernel uses 32bit indices. Since we
      // (specifically TensorMetaData) store indices in 64bit, we can't
      // directly copy our buffer into the args buffer. We thus have to
      // manually downcast each element to fit in the smaller buffer.
      for (size_t i = 0; i < shape.size(); ++i) {
        const int32_t shp = static_cast<int32_t>(shape[i]);
        memcpy(offsets[1] + i * sizeof(int32_t), &shp, sizeof(int32_t));
      }
      // In rare cases we have fewer strides than shapes
      for (size_t i = 0; i < strides.size(); ++i) {
        const int32_t strd = static_cast<int32_t>(strides[i]);
        memcpy(offsets[2] + i * sizeof(int32_t), &strd, sizeof(int32_t));
      }
    } break;
    default:
      NVF_CHECK(0, "Unhandled index type size");
      break;
  }
}

// set the arguments that we'll pass to cuL...

@csarofeen
Collaborator Author

!test

@csarofeen
Collaborator Author

!test

Collaborator

@jacobhinkle left a comment


Looks good to me but I'll let others look before stamping. Should we next rename KernelArgumentHolder since it is no longer only arguments but also kernel outputs? I don't know of a good term off hand: maybe something like KernelIOContainer?

@@ -64,7 +64,7 @@ def get_python_tests(python_test_dir):

def get_test_timeout(test_name):
"""Return timeout in seconds for a given test"""
if test_name in ["test_nvfuser", "test_matmul", "test_ops"]:
Collaborator


Why does test_name include the extension only for that one test?

Collaborator Author


I don't follow your question.

test_nvfuser, test_matmul, and test_ops.py are the only three tests that take a very long time (well over 10 minutes), so we're just adjusting their timeouts accordingly.
