From bc16eeb4d10efc0de0a8832588ec5c217c7cf5dc Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Sat, 22 Feb 2025 15:19:58 -0800 Subject: [PATCH 1/6] Try to get scalar outputs. --- csrc/fusion.cpp | 8 +- csrc/fusion_segmenter.cpp | 233 +---------------------------- csrc/fusion_segmenter.h | 10 +- csrc/runtime/allocations.cpp | 23 ++- csrc/scheduler/expr_eval_sched.cpp | 29 ++-- 5 files changed, 40 insertions(+), 263 deletions(-) diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index c66650068c9..f2e8285dc65 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -253,11 +253,9 @@ void Fusion::addInput(Val* input) { void Fusion::addOutputInternal(Val* output) { assertInContainer(output, "Cannot register output "); - NVF_CHECK( - output->isA(), - "Non-TensorView outputs are not supported at this point: ", - output->toString()); - output->as()->setMemoryType(MemoryType::Global); + if (output->isA()) { + output->as()->setMemoryType(MemoryType::Global); + } outputs_.push_back(output); output->setIsFusionOutput(true); diff --git a/csrc/fusion_segmenter.cpp b/csrc/fusion_segmenter.cpp index def9b58ea51..8c06b7b62fc 100644 --- a/csrc/fusion_segmenter.cpp +++ b/csrc/fusion_segmenter.cpp @@ -3980,10 +3980,8 @@ void SegmentCandidateFinder::buildInitialSegments() { // Initialize DAG, convert each expr to a segment group auto exprs = completeFusion()->exprs(); for (auto expr : exprs) { - if (!ir_utils::isScalarOp(expr)) { - auto new_group = segmented_fusion_->newGroup(expr); - expr2group.insert(std::make_pair(expr, new_group)); - } + auto new_group = segmented_fusion_->newGroup(expr); + expr2group.insert(std::make_pair(expr, new_group)); } // TODO(wujingyue): remove singleton groups that are forwarded. They are @@ -3992,11 +3990,6 @@ void SegmentCandidateFinder::buildInitialSegments() { // Create edges between the Exprs. Mark inputs and outputs of the fusion. for (auto expr : exprs) { - // No group created for scalar ops - if (ir_utils::isScalarOp(expr)) { - continue; - } - if (excluded_inp_unary_exprs_.has(expr)) { continue; } @@ -4019,12 +4012,6 @@ void SegmentCandidateFinder::buildInitialSegments() { continue; } - // No group created for scalar ops since they may need to be duplicated - // to avoid scalar edges. They are handled in resolveScalarsInGroup - if (inp->isScalar()) { - continue; - } - auto def_group = expr2group.at(inp->definition()); auto new_edge = segmented_fusion_->newEdge(def_group, expr_group, inp); expr_group->producer_edges.push_back(new_edge); @@ -4081,34 +4068,10 @@ void SegmentCandidateFinder::resolveForwardedInputs() { continue; } - if (forwarded_input->isScalar()) { - // Scalar forwarded inputs will be resolved after this loop. - // resolveNonscalarForwardedInput resolves only non-scalar ones because - // consumer_edges of a scalar input is always empty due to - // `removeScalarEdges`. - continue; - } - - resolveNonscalarForwardedInput(forwarded_input); + resolveForwardedInput(forwarded_input); // aux_group will be removed from segmented_fusion_ by // cleanupForwardedInputs. } - - // Un-forward scalar inputs unconditionally. 
- for (SegmentedGroup* group : segmented_fusion_->groups()) { - std::vector forwarded_scalar_inputs; - for (Val* input_val : group->inputs()) { - if (!input_val->isFusionInput() && input_val->isScalar()) { - forwarded_scalar_inputs.push_back(input_val); - } - } - - group->input_vals = IterVisitor::getInputsTo(group->inputs()); - auto input_exprs = StmtSort::getExprsTo(forwarded_scalar_inputs); - // Insert those expressions at the beginning of the group - group->exprs_.insert( - group->exprs_.begin(), input_exprs.begin(), input_exprs.end()); - } } void SegmentCandidateFinder::findSegments() { @@ -4139,10 +4102,6 @@ void SegmentCandidateFinder::findSegments() { } } - // Remove all scalar edges since they do not represent actual - // dependency among segmented groups. - removeScalarEdges(); - // Run pre-merge heuristics MergeUpAndDownCast::run(this); segmented_fusion_->validateIfDebug(true); @@ -4482,7 +4441,7 @@ void SegmentCandidateFinder::forwardInputs() { forwarded_inputs.pushBack(uop->out()); } // Either way, `uop` is excluded from merging until - // `resolveNonscalarForwardedInput` adds it back to one of the segments. + // `resolveForwardedInput` adds it back to one of the segments. excluded_inp_unary_exprs_.pushBack(uop); } } @@ -4590,154 +4549,6 @@ void SegmentCandidateFinder::finalMerge() { } } -void SegmentCandidateFinder::resolveScalarsInGroup(SegmentedGroup* group) { - std::vector to_visit; - std::unordered_set visited; - - const auto processTV = [&to_visit](TensorView* tv) { - for (auto id : TensorDomain::noReductions(tv->getMaybeRootDomain())) { - to_visit.push_back(id->getMaybeExpandedExtent()); - } - if (tv->domain()->hasRoot()) { - // traverse from root to logical and inspect all Expr attrs and outputs - std::vector all_vals; - for (const auto id_expr : StmtSort::getExprsBetween( - {tv->getRootDomain().begin(), tv->getRootDomain().end()}, - {tv->getLogicalDomain().begin(), - tv->getLogicalDomain().end()})) { - all_vals.insert( - all_vals.end(), id_expr->inputs().begin(), id_expr->inputs().end()); - all_vals.insert( - all_vals.end(), - id_expr->outputs().begin(), - id_expr->outputs().end()); - for (const auto attr : id_expr->attributes()) { - if (attr && attr->isVal()) { - all_vals.push_back(attr->asVal()); - } - } - for (const auto val : all_vals) { - if (val->isScalar()) { - to_visit.push_back(val); - } else if (const auto id = dynamic_cast(val)) { - to_visit.push_back(id->getMaybeExpandedExtent()); - } - } - } - } - }; - - // Segment TensorView inputs will have their logical extents available, so we - // avoid adding them as separate scalar inputs. - for (auto e : group->producer_edges) { - if (const auto tv = dynamic_cast(e->val)) { - for (auto id : TensorDomain::noReductions(tv->getLogicalDomain())) { - visited.insert(id->getMaybeExpandedExtent()); - } - } - } - - // Collect all scalar uses in the group - for (auto expr : group->exprs()) { - for (auto input : expr->inputs()) { - if (input->isScalar()) { - to_visit.push_back(input); - } else if (auto tv = dynamic_cast(input); tv && - std::none_of(group->producer_edges.begin(), - group->producer_edges.end(), - [&tv](SegmentedEdge* e) { - return e->val == tv; - })) { - // Intermediate group inputs (producer edges) will have their logical - // domain reassigned as the root domain, so there is no need to process - // them. Tensors computed inside this group will need processing, - // however, as their root->logical transforms must be computed in this - // group. 
- processTV(tv); - } - } - for (auto attr : expr->attributes()) { - auto attr_val = dynamic_cast(attr); - if (!attr_val) { - continue; - } - if (attr_val->isScalar()) { - to_visit.push_back(attr_val); - } else if (auto tv = dynamic_cast(attr_val)) { - processTV(tv); - } - } - for (auto output : expr->outputs()) { - // We must be able to compute output extents for expression, so here we - // ensure the scalars involved are all available to this group - if (auto tv = dynamic_cast(output)) { - processTV(tv); - } - } - } - - // Keep track of composite fusion inputs used in this group - std::unordered_set input_set; - for (auto inp : group->input_vals) { - input_set.insert(inp); - if (auto tv = dynamic_cast(inp)) { - for (IterDomain* id : - TensorDomain::noReductions(tv->getLogicalDomain())) { - // Extents of inputs will already be bound. This prevents adding them - // as redundant inputs. - input_set.insert(id->getMaybeExpandedExtent()); - } - } - } - - // Record and append all missing scalar exprs at the end. - std::vector exprs_to_add; - - // Do a stack based traversal of the scalar ops to avoid - // combinatorial duplication of exprs. - while (!to_visit.empty()) { - auto stack_top_val = to_visit.back(); - if (visited.count(stack_top_val)) { - to_visit.pop_back(); - } else if (stack_top_val->definition() == nullptr) { - // A scalar without def can be a scalar, a tensor dim, - // or a composite fusion input - // The first two cases are handled in finalize(), - // the last case needs to add new input_val to this group. - visited.insert(stack_top_val); - // If this is a composite fusion scalar input, make sure this group has it - if (stack_top_val->isFusionInput() && !input_set.count(stack_top_val)) { - group->input_vals.push_back(stack_top_val); - input_set.insert(stack_top_val); - } - to_visit.pop_back(); - } else { - // A scalar with an actual definition - auto definition_expr = stack_top_val->definition(); - bool all_inputs_visited = true; - // If any of the inputs are not visited, visit them first - for (auto input : definition_expr->inputs()) { - if (!visited.count(input)) { - all_inputs_visited = false; - to_visit.push_back(input); - } - } - // This node is ready to be visited - if (all_inputs_visited) { - // Collect the defining expr to insert into group - exprs_to_add.push_back(definition_expr); - visited.insert(stack_top_val); - to_visit.pop_back(); - } - } - } - - // Add all the defining expr to the group - for (auto expr : exprs_to_add) { - group->exprs_.push_back(expr); - } -} - SegmentedGroup* SegmentCandidateFinder::createInputGroup(Val* forwarded_input) { SegmentedGroup* group = segmented_fusion_->newGroup(); group->input_vals = IterVisitor::getInputsTo({forwarded_input}); @@ -4745,8 +4556,7 @@ SegmentedGroup* SegmentCandidateFinder::createInputGroup(Val* forwarded_input) { return group; } -void SegmentCandidateFinder::resolveNonscalarForwardedInput( - Val* forwarded_input) { +void SegmentCandidateFinder::resolveForwardedInput(Val* forwarded_input) { SegmentedGroup* aux_group = input2group_.at(forwarded_input); NVF_ERROR(aux_group->producer_edges.empty()); @@ -4784,30 +4594,6 @@ void SegmentCandidateFinder::resolveNonscalarForwardedInput( } } -void SegmentCandidateFinder::removeScalarEdges() { - // Remove all scalar edges between groups - // They may have been created by welford - // translation. 
- // we will not need them after scalar - // resolution - auto remove_scalar_edges_from_vec = [](std::vector& edges) { - edges.erase( - std::remove_if( - edges.begin(), - edges.end(), - [](SegmentedEdge* segmented_edge) { - return segmented_edge->val->isScalar(); - }), - edges.end()); - }; - - remove_scalar_edges_from_vec(edges()); - for (auto group : groups()) { - remove_scalar_edges_from_vec(group->producer_edges); - remove_scalar_edges_from_vec(group->consumer_edges); - } -} - void SegmentCandidateFinder::finalize() { // Remove unconnected groups groups().erase( @@ -4830,11 +4616,6 @@ void SegmentCandidateFinder::finalize() { // Finalize connections between segmented groups segmented_fusion_->finalize(); - // Resolve all the scalar expressions needed in each group - for (auto group : segmented_fusion_->groups()) { - resolveScalarsInGroup(group); - } - for (auto group : segmented_fusion_->groups()) { revertPrivatizedUpcast(group); } @@ -4959,10 +4740,6 @@ void SegmentedFusion::validateDisjoint() const { } for (auto expr : group->exprs()) { - // Allow scalar exprs to exist in multiple groups - if (ir_utils::isScalarOp(expr)) { - continue; - } NVF_ERROR( exprs.insert(expr).second, "Duplicate expression detected: ", diff --git a/csrc/fusion_segmenter.h b/csrc/fusion_segmenter.h index 67d68acd2c0..afc7cf807f6 100644 --- a/csrc/fusion_segmenter.h +++ b/csrc/fusion_segmenter.h @@ -636,10 +636,6 @@ class SegmentCandidateFinder { //! produces the consumer. void finalMerge(); - //! Duplicate and add all exprs producing the used - //! scalar values in group - void resolveScalarsInGroup(SegmentedGroup* group); - //! Duplicate and add all exprs from fusion inputs to `forwarded_input` into //! the group, to complete inputs. These expressions are simply unary ops of //! inputs that we want to recompute for each segment, instead of computing @@ -652,7 +648,7 @@ class SegmentCandidateFinder { //! If we segmented on tv1, we would be producing an output for tv1 for 2 //! groups that have tv3 or tv4, instead we could easily recompute tv1 from //! tv0. - void resolveNonscalarForwardedInput(Val* forwarded_input); + void resolveForwardedInput(Val* forwarded_input); void resolveForwardedInputs(); @@ -660,10 +656,6 @@ class SegmentCandidateFinder { // between fusion inputs and `forwarded_input`. SegmentedGroup* createInputGroup(Val* forwarded_input); - //! Remove all scalar edges in group - //! (TODO: need structure better so we don't have to do this) - void removeScalarEdges(); - //! Utility function to merge a vector of groups in one step, //! need to check for DAG condition before using this method SegmentedGroup* mergeAllGivenGroups( diff --git a/csrc/runtime/allocations.cpp b/csrc/runtime/allocations.cpp index 0a172c38201..62182bd7890 100644 --- a/csrc/runtime/allocations.cpp +++ b/csrc/runtime/allocations.cpp @@ -41,15 +41,32 @@ KernelArgumentHolder inferOutputSizes( output_tensor_proxies.setDeviceIndex(args.getDeviceIndex()); for (Val* output : fusion->outputs()) { - NVF_ERROR( - output->isA(), - "Cannot allocate outputs that are not tensors."); + if (output->isA()) { auto output_tv = output->as(); const auto& [sizes, strides] = inferShapeOfOutput(output_tv, expr_eval); const auto dtype = (output_tv->dtype() == DataType::Index) ? 
data_type_to_aten(arg_index_type) : data_type_to_aten(output_tv->dtype()); output_tensor_proxies.pushTensorProxy(sizes, strides, dtype); + } else if(output->isScalar()){ + switch(std::get(output->dtype().type)){ + case DataType::Int: + case DataType::Int32: + output_tensor_proxies.push(PolymorphicValue(0LL)); + break; + case DataType::Double: + case DataType::Float: + output_tensor_proxies.push(PolymorphicValue(0.0)); + break; + case DataType::Bool: + output_tensor_proxies.push(PolymorphicValue(false)); + break; + default: + NVF_ERROR("Output type not supported: ", output->toString()); + } + } else { + NVF_ERROR("Output type not supported: ", output->toString()); + } } return output_tensor_proxies; } diff --git a/csrc/scheduler/expr_eval_sched.cpp b/csrc/scheduler/expr_eval_sched.cpp index 684beae07fa..c59e8f258bb 100644 --- a/csrc/scheduler/expr_eval_sched.cpp +++ b/csrc/scheduler/expr_eval_sched.cpp @@ -44,31 +44,24 @@ bool ExprEvalScheduler::canScheduleCompileTime(Fusion* fusion) { return true; } - auto exprs = fusion->exprs(); - if (exprs.size() != 1) { - scheduler_debug_utils::canScheduleRejectReason( - schedulerType(), "Fusion must contain only a single expression."); - return false; - } + auto expr_check = [](Expr* expr) { + return expr->isOneOf() || + (expr->isOneOf() && + !isOptionDisabled(DisableOption::MatmulExprEval)) || + ir_utils::isScalarOp(expr); + }; - if (exprs.front()->isOneOf()) { - return true; - } + auto exprs = fusion->exprs(); - if (exprs.front()->isOneOf()) { - if (isOptionDisabled(DisableOption::MatmulExprEval)) { + for (auto expr : exprs) { + if (!expr_check(expr)) { scheduler_debug_utils::canScheduleRejectReason( - schedulerType(), - "Matmul ATen evaluation was disabled by NVFUSER_DISABLE=matmul_expr_eval"); + "Expr not supported in ExprEvalScheduler:", expr->toString()); return false; } - return true; } - scheduler_debug_utils::canScheduleRejectReason( - schedulerType(), - "Fusion must contain only a single expression of type MatmulOp/LinearOp/SdpaFwdOp/SdpaBwdOp"); - return false; + return true; } void ExprEvalScheduler::schedule( From 67a53a1aede23e0cdd2799b84ce0bd43206c08ae Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Sat, 22 Feb 2025 17:58:15 -0800 Subject: [PATCH 2/6] Get ExprEvalExec working, errors and segfaults in tests. 
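The kind of fusion this series is aiming at, roughly (illustrative sketch only, not a test added by this patch; assumes the usual nvFuser C++ test setup with a FusionExecutorCache driver and runFusionWithInputs):

    auto fusion = std::make_unique<Fusion>();
    FusionGuard fg(fusion.get());
    // Scalar input and a purely scalar computation.
    Val* d0 = IrBuilder::create<Val>(DataType::Double);
    fusion->addInput(d0);
    Val* d1 = mul(add(d0, IrBuilder::create<Val>(1.0)), d0);
    // Scalar fusion output: previously rejected when registering outputs,
    // now expected to be evaluated on the host by ExprEvalExecutor.
    fusion->addOutput(d1);

    FusionExecutorCache executor_cache(std::move(fusion));
    auto outputs = executor_cache.runFusionWithInputs({3.0});  // expect 12.0

To support this, inferOutputSizes() pushes a zero-initialized PolymorphicValue proxy for scalar outputs, ExprEvalExecutor::run() evaluates any output Val rather than only TensorViews, and testValidate() now compares scalar outputs as well.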
--- csrc/fusion_segmenter.cpp | 39 ++++++++++++++++++++ csrc/runtime/allocations.cpp | 16 ++++----- csrc/runtime/executor.cpp | 9 ++--- csrc/scheduler/registry.cpp | 12 +++++++ tests/cpp/test_gpu2.cpp | 10 +++--- tests/cpp/validator.cpp | 69 +++++++++++++++++++++++++++--------- tests/cpp/validator.h | 2 +- 7 files changed, 123 insertions(+), 34 deletions(-) diff --git a/csrc/fusion_segmenter.cpp b/csrc/fusion_segmenter.cpp index 8c06b7b62fc..eb391c9f2dc 100644 --- a/csrc/fusion_segmenter.cpp +++ b/csrc/fusion_segmenter.cpp @@ -1997,6 +1997,42 @@ std::pair> SegmentedFusion::makeFusion( return std::make_pair(complete_to_segment_map, std::move(fusion_segment)); } +namespace { +void simplifyConstantScalars(Fusion* fusion) { + FUSER_PERF_SCOPE("simplifyConstantScalars"); + + std::unordered_map replace_const_scalars; + + // Find all scalar expressions with constant inputs + for (auto expr : fusion->exprs()) { + auto has_const = + std::any_of(expr->inputs().begin(), expr->inputs().end(), [](Val* v) { + return v->isConstScalar(); + }); + auto has_non_const = + std::any_of(expr->inputs().begin(), expr->inputs().end(), [](Val* v) { + return !v->isConstScalar(); + }); + if (!(has_const && has_non_const)) { + continue; + } + for (auto const_scalar : ir_utils::filterByType(expr->inputs())) { + if (!const_scalar->isConstScalar()) { + continue; + } + // No definition + if (const_scalar->isConst()) { + continue; + } + auto value = const_scalar->evaluate(); + replace_const_scalars[const_scalar] = + IrBuilder::create(value, const_scalar->getDataType().value()); + } + } + ir_utils::replaceValue(fusion, replace_const_scalars); +} +} // namespace + std::unique_ptr SegmentCandidateFinder::segment( const Fusion* fusion, const KernelArgumentHolder& inputs, @@ -3944,6 +3980,9 @@ SegmentCandidateFinder::SegmentCandidateFinder( SegmentCandidateFinderOptions options, bool multi_device) : options_(options), runtime_inputs_(inputs) { + // Remove constant scalar expressions so they don't get segmented to the expr + // eval executor. + simplifyConstantScalars(fusion.get()); NVF_ERROR( !options_.only_segment_resharding_exprs || (!options_.run_translate_welford && diff --git a/csrc/runtime/allocations.cpp b/csrc/runtime/allocations.cpp index 62182bd7890..d2ed6cb6ef8 100644 --- a/csrc/runtime/allocations.cpp +++ b/csrc/runtime/allocations.cpp @@ -42,14 +42,14 @@ KernelArgumentHolder inferOutputSizes( for (Val* output : fusion->outputs()) { if (output->isA()) { - auto output_tv = output->as(); - const auto& [sizes, strides] = inferShapeOfOutput(output_tv, expr_eval); - const auto dtype = (output_tv->dtype() == DataType::Index) - ? data_type_to_aten(arg_index_type) - : data_type_to_aten(output_tv->dtype()); - output_tensor_proxies.pushTensorProxy(sizes, strides, dtype); - } else if(output->isScalar()){ - switch(std::get(output->dtype().type)){ + auto output_tv = output->as(); + const auto& [sizes, strides] = inferShapeOfOutput(output_tv, expr_eval); + const auto dtype = (output_tv->dtype() == DataType::Index) + ? 
data_type_to_aten(arg_index_type) + : data_type_to_aten(output_tv->dtype()); + output_tensor_proxies.pushTensorProxy(sizes, strides, dtype); + } else if (output->isScalar()) { + switch (std::get(output->dtype().type)) { case DataType::Int: case DataType::Int32: output_tensor_proxies.push(PolymorphicValue(0LL)); diff --git a/csrc/runtime/executor.cpp b/csrc/runtime/executor.cpp index 57fd91224e8..5a455a7f4a0 100644 --- a/csrc/runtime/executor.cpp +++ b/csrc/runtime/executor.cpp @@ -62,6 +62,8 @@ bool ExprEvalExecutor::supported(Fusion* fusion) { } void ExprEvalExecutor::compile(Fusion* fusion) { + std::cout << "Using ExprEvalExecutor" << std::endl; + fusion->printMath(); FUSER_PERF_SCOPE("ExprEvalExecutor::compile"); if (isProfilerEnabled()) { FusionProfiler::segment(group_id_).startCompile(); @@ -105,10 +107,9 @@ KernelArgumentHolder ExprEvalExecutor::run( " and expects that the outputs are not populated, which they were."); if (outputs.empty()) { for (const auto& out_val : fusion_->outputs()) { - auto out_tensor = - expr_eval.evaluate(out_val->as()).as(); - expr_eval.bind(out_val, out_tensor); - outputs.push(out_tensor); + auto evaled_out = expr_eval.evaluate(out_val); + expr_eval.bind(out_val, evaled_out); + outputs.push(evaled_out); } } } diff --git a/csrc/scheduler/registry.cpp b/csrc/scheduler/registry.cpp index 2c002b7ecbe..6e031a17cc9 100644 --- a/csrc/scheduler/registry.cpp +++ b/csrc/scheduler/registry.cpp @@ -143,6 +143,18 @@ bool canSchedule( std::unique_ptr scheduler = SchedulerEntry::makeSchedulerInstance(scheduler_type); + if (std::any_of( + fusion->outputs().begin(), + fusion->outputs().end(), + [](Val* v) { return v->isScalar(); }) && + (scheduler_type != SchedulerType::ExprEval)) { + scheduler_debug_utils::canScheduleMessage( + "***Rejected*** scheduler ", + scheduler_type, + " cannot accept scalar outputs"); + return false; + } + if (!skip_compile_time_checks && !scheduler->canScheduleCompileTime(fusion)) { return false; } diff --git a/tests/cpp/test_gpu2.cpp b/tests/cpp/test_gpu2.cpp index 1e327fc4c73..cb030a34631 100644 --- a/tests/cpp/test_gpu2.cpp +++ b/tests/cpp/test_gpu2.cpp @@ -4784,7 +4784,7 @@ TEST_F(NVFuserTest, FusionDAGMerging_CUDA) { NVF_CHECK(fusion_segments->groups().size() <= 4); } -TEST_F(NVFuserTest, FusionDAGScalarMerging_CUDA) { +TEST_F(NVFuserTest, FusionScalarProcessing_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -4794,14 +4794,16 @@ TEST_F(NVFuserTest, FusionDAGScalarMerging_CUDA) { fusion->addInput(tv0); fusion->addInput(i0); + // ExprEvalExecutor will take these scalar ops auto i1 = add(i0, IrBuilder::create(1.0)); auto i2 = mul(i1, i1); auto i3 = add(i2, i1); - // Branch 0 + // Kernel 0 auto tv1 = sum(tv0, {0}); // 0 auto tv2 = add(tv1, i2); - // Branch 1 + + // Kernel 1 auto tv3 = sum(tv2, {0}); // 1 auto tv4 = add(tv3, i3); @@ -4824,7 +4826,7 @@ TEST_F(NVFuserTest, FusionDAGScalarMerging_CUDA) { executor_cache.getMostRecentKernelRuntime() ->fusionSegments() ->groups() - .size() == 2, + .size() == 3, "segmentation didn't happen as expected"); testValidate(executor_cache.fusion(), outputs, {t0, s0}, __LINE__, __FILE__); diff --git a/tests/cpp/validator.cpp b/tests/cpp/validator.cpp index 3f29b283583..758fd4388a3 100644 --- a/tests/cpp/validator.cpp +++ b/tests/cpp/validator.cpp @@ -16,7 +16,7 @@ void testValidate( Fusion* fusion, const KernelArgumentHolder& fusion_outputs, const KernelArgumentHolder& aten_inputs, - std::vector aten_outputs, + const KernelArgumentHolder& aten_outputs, int line_number, const char* 
file_name, std::string err_msg, @@ -41,12 +41,6 @@ void testValidate( auto reduction_sizes = ReductionSizeMapper::computeReductionSizes(fusion, expr_eval); - if (aten_outputs.empty()) { - for (Val* out : non_hidden_outputs) { - aten_outputs.emplace_back(expr_eval.evaluate(out).as()); - } - } - NVF_ERROR( fusion_outputs.size() == aten_outputs.size(), "Number of outputs don't match: ", @@ -79,15 +73,35 @@ void testValidate( for (auto i : c10::irange(non_hidden_outputs.size())) { Val* out = non_hidden_outputs[i]; - NVF_ERROR(out->isA()); - TensorView* out_tv = out->as(); + if (!out->isA()) { + if (fusion_outputs[i].is()) { + NVF_ERROR( + aten_outputs[i].is(), + "Validation failed mismatched types."); + NVF_ERROR( + abs(fusion_outputs[i].as() - aten_outputs[i].as()) < + 1e-5, + "Validation failed ", + fusion_outputs[i].as(), + " != ", + aten_outputs[i].as()); + } else { + NVF_ERROR( + PolymorphicValue_functions::isSame( + fusion_outputs[i], aten_outputs[i]), + "Output, ", + i, + " mismatch: ", + PolymorphicValue_functions::toString(fusion_outputs[i]), + " != ", + PolymorphicValue_functions::toString(aten_outputs[i])); + } + } - NVF_ERROR( - fusion_outputs[i].is(), - "Fusion output is not a tensor at index ", - i); - const at::Tensor& fusion_output_tensor = fusion_outputs[i].as(); - const at::Tensor& aten_output_tensor = aten_outputs[i]; + const at::Tensor fusion_output_tensor = fusion_outputs[i].as(); + const at::Tensor aten_output_tensor = aten_outputs[i].as(); + + TensorView* out_tv = out->as(); NVF_ERROR( reduction_sizes.count(out_tv), @@ -98,7 +112,7 @@ void testValidate( NVF_ERROR( aten_output_tensor.dim() == fusion_output_tensor.dim() && - fusion_output_tensor.dim() == + aten_output_tensor.dim() == static_cast( TensorDomain::noReductions(out_tv->getLogicalDomain()) .size()), @@ -162,11 +176,32 @@ void testValidate( std::string err_msg, const LaunchParams& lparams, const ValidationConstants& tolerances) { + std::vector non_hidden_outputs; + std::copy_if( + fusion->outputs().begin(), + fusion->outputs().end(), + std::back_inserter(non_hidden_outputs), + [fusion](Val* out) { + // Returns true when `out` is **not** an aliased output that's hidden + // from integration. Hidden outputs won't show up in `fusion_outputs` + // for us to compare, so we skip them. + return !fusion->getOutputAlias(out).hide_output; + }); + + auto expr_eval = bindInputsAndLaunchParams(fusion, aten_inputs, lparams); + + KernelArgumentHolder aten_outputs; + if (aten_outputs.empty()) { + for (Val* out : non_hidden_outputs) { + aten_outputs.push(expr_eval.evaluate(out)); + } + } + testValidate( fusion, fusion_outputs, aten_inputs, - /*aten_outputs=*/{}, + aten_outputs, line_number, file_name, err_msg, diff --git a/tests/cpp/validator.h b/tests/cpp/validator.h index e29caa5983d..4008eed179b 100644 --- a/tests/cpp/validator.h +++ b/tests/cpp/validator.h @@ -36,7 +36,7 @@ void testValidate( Fusion* fusion, const KernelArgumentHolder& fusion_outputs, const KernelArgumentHolder& aten_inputs, - std::vector aten_outputs, + const KernelArgumentHolder& aten_outputs, int line_number, const char* file_name, std::string err_msg = "", From 0c376ae77ff669159ef89a157ee62122dd099d32 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Sun, 23 Feb 2025 08:18:12 -0800 Subject: [PATCH 3/6] Consume no ops in expr_eval. 
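The per-expression check in ExprEvalScheduler::canScheduleCompileTime now also accepts expressions that should not need a kernel of their own. Illustratively (not a test in this patch; tv0 assumed 1-D), all of the following pass the new isNoOp() check:

    TensorView* tv1 = segment_set(tv0);              // LoadStoreOpType::SegmenterSet
    TensorView* tv2 = set(tv1);                      // LoadStoreOpType::Set
    TensorView* tv3 = broadcast(tv2, {false, true}); // shape-only op

as do SqueezeOp/SliceOp/CatOp/ViewOp/RepeatOp and reduction ops whose non-reduction output extents are all constant zero (i.e. provably empty outputs), on top of the scalar ops and MatmulOp/LinearOp/Sdpa* expressions already accepted.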
--- csrc/scheduler/expr_eval_sched.cpp | 36 +++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/csrc/scheduler/expr_eval_sched.cpp b/csrc/scheduler/expr_eval_sched.cpp index c59e8f258bb..2a7a64e70c6 100644 --- a/csrc/scheduler/expr_eval_sched.cpp +++ b/csrc/scheduler/expr_eval_sched.cpp @@ -30,6 +30,40 @@ bool allOutputsArePointerArithmetics(Fusion* fusion) { return root != nullptr && root->isFusionInput(); }); } + +bool isNoOp(Expr* expr) { + if (expr->isA() && + (expr->as()->opType() == LoadStoreOpType::Set || + expr->as()->opType() == LoadStoreOpType::SegmenterSet)) { + return true; + } + if (ir_utils::isReductionOp(expr)) { + for (auto out_tv : ir_utils::filterByType(expr->outputs())) { + const std::vector& logical_dom = + TensorDomain::noReductions(out_tv->getLogicalDomain()); + const bool non_zero_reduction = std::any_of( + logical_dom.begin(), logical_dom.end(), [](IterDomain* id) { + return !( + id->extent()->isConstScalar() && + id->extent()->evaluate().as() == 0); + }); + if (non_zero_reduction) { + return false; + } + } + return true; + } + if (expr->isOneOf< + SqueezeOp, + BroadcastOp, + SliceOp, + CatOp, + ViewOp, + RepeatOp>()) { + return true; + } + return false; +} } // namespace // Check if the fusion has a single MatmulOp/LinearOp node @@ -48,7 +82,7 @@ bool ExprEvalScheduler::canScheduleCompileTime(Fusion* fusion) { return expr->isOneOf() || (expr->isOneOf() && !isOptionDisabled(DisableOption::MatmulExprEval)) || - ir_utils::isScalarOp(expr); + ir_utils::isScalarOp(expr) || isNoOp(expr); }; auto exprs = fusion->exprs(); From ec8387fe658e02abdfeb460147157557bb38033f Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Sun, 23 Feb 2025 08:19:10 -0800 Subject: [PATCH 4/6] Deduplicate inputs from segmenter. Allow ExprEval to take segmenter set ops. --- csrc/fusion_segmenter.cpp | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/csrc/fusion_segmenter.cpp b/csrc/fusion_segmenter.cpp index eb391c9f2dc..bd23cbfd29d 100644 --- a/csrc/fusion_segmenter.cpp +++ b/csrc/fusion_segmenter.cpp @@ -2636,6 +2636,16 @@ SchedulerType tryMerge( "\n**Segmenter** Considering fusion:\n", segmented_fusion->completeFusion()); if (tryingToMergeSegmenterSet(segmented_fusion->completeFusion())) { + if (Schedule::canSchedule( + SchedulerType::ExprEval, + segmented_fusion->completeFusion(), + runtime_info)) { + scheduler_debug_utils::canScheduleMessage( + "***Accepted*** as: ", SchedulerType::ExprEval); + return SchedulerType::ExprEval; + } + scheduler_debug_utils::canScheduleMessage( + "***Rejected*** failed tryingToMergeSegmenterSet"); return SchedulerType::None; } return Schedule::proposeHeuristics( @@ -4207,6 +4217,26 @@ void SegmentCandidateFinder::findSegments() { finalize(); + for (auto group : groups()) { + // I don't understand why but we're getting some duplicate inputs for the + // following segment: Group: pointwise{0} Inputs: + // T0_g_float[iS0{2}, iS1{3}] + // Outputs: + // T1_g_float[iS2{2}, iS3{3}] + + // %kernel_math { + // T1_g_float[iS2{2}, iS3{3}] + // = T0_g_float[iS0{2}, iS1{3}] + // + T0_g_float[iS0{2}, iS1{3}]; + // } // %kernel_math + // It's listing T0 as an input twice. For now WAR with deduplicating the + // inputs. + // + // TODO: Figure out what's happening in segmentation to make sure this + // doesn't happen. + group->input_vals = VectorOfUniqueEntries(group->input_vals).vector(); + } + // NVF_THROW("TMP"); // Do sanity check on the final graph. 
segmented_fusion_->validate(/*require_disjoint=*/false); From 4f0a491536611d4f9d60f4a4fec249edfee00151 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Sun, 23 Feb 2025 08:30:08 -0800 Subject: [PATCH 5/6] Debugging CombinedSchedulerTest.LayerNormBackward/dtype___half_batch_216_hidden_65536 in fusion_cache_utils.cpp. --- csrc/runtime/fusion_cache_utils.cpp | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/csrc/runtime/fusion_cache_utils.cpp b/csrc/runtime/fusion_cache_utils.cpp index 5546ce81fac..876d8cfdd51 100644 --- a/csrc/runtime/fusion_cache_utils.cpp +++ b/csrc/runtime/fusion_cache_utils.cpp @@ -171,6 +171,10 @@ void prepareRuntimeOrder( } } + for (auto group : segmented_fusion->groups()) { + std::cout << "Group: " << toString(group) << std::endl; + group->getFusion()->print(); + } // Keep track of groups that has run std::vector group_ran(segmented_fusion->groups().size(), false); @@ -181,6 +185,9 @@ void prepareRuntimeOrder( // Find the first segment with all inputs available to run for (const size_t group_i : c10::irange(segmented_fusion->groups().size())) { + if(group_ran[group_i]){ + continue; + } auto& group = segmented_fusion->groups()[group_i]; if (group_ran[group_i]) { continue; @@ -192,6 +199,7 @@ void prepareRuntimeOrder( [&available_input](Val* val) { return available_input.count(val); }); if (ready_to_run) { + std::cout << "Running group: " << toString(group) << std::endl; runtime_workspace.group_run_order.push_back(group); const auto& group_outputs = group->outputs(); @@ -203,6 +211,23 @@ void prepareRuntimeOrder( one_ran = true; } } + + for (const size_t group_i : + c10::irange(segmented_fusion->groups().size())) { + std::cout << "Checking: " << toString(segmented_fusion->groups()[group_i]) + << std::endl; + auto& group = segmented_fusion->groups()[group_i]; + if (group_ran[group_i]) { + continue; + } + const auto& group_inputs = group->inputs(); + for (auto group_inp : group_inputs) { + if (!available_input.count(group_inp)) { + std::cout << "Not available: " << group_inp->toString() << std::endl; + } + } + } + NVF_ERROR( one_ran, "Couldn't run all groups, something must have gone wrong in segmentation."); From 6f7b52e357cc17d15d7cae3489d577b31172ceee Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Sun, 23 Feb 2025 19:04:48 -0800 Subject: [PATCH 6/6] Having segmentation issues, trying to resolve them. 
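WIP notes on the reworked input forwarding (illustrative example only; the tv names are made up, not from a test):

    // tv0 is a fusion input feeding a chain of forwardable unary ops
    TensorView* tv1 = castOp(DataType::Float, tv0);
    TensorView* tv2 = neg(tv1);
    TensorView* tv3 = add(tv2, tv_other);  // first non-unary use

forwardInputs() now walks the unary chain from each input and records

    input_to_forward_val_[tv0] = tv2;
    forward_val_to_input_[tv2] = tv0;

so only the forwarded value tv2 gets an auxiliary input group. resolveForwardedInputs() is then meant to recompute the expressions between tv0 and tv2 (via DependencyCheck::getAllExprsBetween) inside each group that consumes tv2, instead of keeping tv2 as a segment edge. The group rewiring is not working yet, hence the extra debug printing and the commented-out fallback code below.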
--- csrc/fusion_segmenter.cpp | 267 +++++++++++++++++++++++++++++--------- csrc/fusion_segmenter.h | 4 + 2 files changed, 208 insertions(+), 63 deletions(-) diff --git a/csrc/fusion_segmenter.cpp b/csrc/fusion_segmenter.cpp index bd23cbfd29d..1b793c69567 100644 --- a/csrc/fusion_segmenter.cpp +++ b/csrc/fusion_segmenter.cpp @@ -4018,6 +4018,44 @@ SchedulerRuntimeInfo& SegmentCandidateFinder::runtimeInfo() { return *runtime_info_; } +SegmentGroup* SegmentCandidateFinder::initializeExprGroup(Expr* expr) { + SegmentedGroup* expr_group = nullptr; + if(expr2group.count(expr)) { + expr_group = expr2group.at(expr); + } else { + expr_group = segmented_fusion_->newGroup(expr); + expr2group.insert(std::make_pair(expr, expr_group)); + } + + for (auto inp : expr->inputs()) { + if (input2group_.count(inp)) { + expr_group->input_vals.push_back(inp); + auto aux_group = input2group_.at(inp); + auto new_edge = segmented_fusion_->newEdge(aux_group, expr_group, inp); + expr_group->producer_edges.push_back(new_edge); + aux_group->consumer_edges.push_back(new_edge); + continue; + } + + // Could be something like a constant scalar, definition is nullptr, but + // isn't an "input" to the fusion. At least not one provided by an + // external source. + if (inp->definition() == nullptr) { + continue; + } + + auto def_group = expr2group.at(inp->definition()); + auto new_edge = segmented_fusion_->newEdge(def_group, expr_group, inp); + expr_group->producer_edges.push_back(new_edge); + def_group->consumer_edges.push_back(new_edge); + } + for (auto out : expr->outputs()) { + if (out->isFusionOutput()) { + expr_group->output_vals.push_back(out); + } + } +} + void SegmentCandidateFinder::buildInitialSegments() { groups().clear(); edges().clear(); @@ -4026,17 +4064,20 @@ void SegmentCandidateFinder::buildInitialSegments() { // Need this for initialization of the DAG that is process std::unordered_map expr2group; + // TODO(wujingyue): remove singleton groups that are forwarded. They are + // useless and cause duplication. + forwardInputs(); + // Initialize DAG, convert each expr to a segment group auto exprs = completeFusion()->exprs(); for (auto expr : exprs) { + if(excluded_inp_unary_exprs_.has(expr)) { + continue; + } auto new_group = segmented_fusion_->newGroup(expr); expr2group.insert(std::make_pair(expr, new_group)); } - // TODO(wujingyue): remove singleton groups that are forwarded. They are - // useless and cause duplication. - forwardInputs(); - // Create edges between the Exprs. Mark inputs and outputs of the fusion. for (auto expr : exprs) { if (excluded_inp_unary_exprs_.has(expr)) { @@ -4045,7 +4086,7 @@ void SegmentCandidateFinder::buildInitialSegments() { SegmentedGroup* expr_group = expr2group.at(expr); for (auto inp : expr->inputs()) { - if (isFusionInput(inp)) { + if (input2group_.count(inp)) { expr_group->input_vals.push_back(inp); auto aux_group = input2group_.at(inp); auto new_edge = segmented_fusion_->newEdge(aux_group, expr_group, inp); @@ -4111,16 +4152,108 @@ void SegmentCandidateFinder::trySetUpMerge( } void SegmentCandidateFinder::resolveForwardedInputs() { - for (Val* forwarded_input : forwarded_fusion_inputs_) { - if (forwarded_input->isFusionInput()) { - // Nothing to resolve. - continue; - } - resolveForwardedInput(forwarded_input); - // aux_group will be removed from segmented_fusion_ by - // cleanupForwardedInputs. 
- } + for(auto [forward_val, inp] : forward_val_to_input_) { + auto inp_group = input2group_.at(forwarded_val); + input2group_.erase(forwarded_val); + eraseGroups({inp_group}); + + auto new_group = segmented_fusion_->newFusionInputGroup(); + input2group_.insert({inp, new_group}); + + std::vector groups_to_resolve; + for(auto group : groups()) { + if(std::find(group->input_vals.begin(), group->input_vals.end(), forward_val) + != group->input_vals.end()) { + groups_to_resolve.push_back(group); + } + } + auto exprs = DependencyCheck::getAllExprsBetween( + {inp}, + {forward_val} + ); + + SegmentedGroup* new_group = nullptr; + new_group->input_vals.push_back(inp); + new_group->exprs_.insert(new_group->exprs_.begin(), exprs.begin(), exprs.end()); + + bool can_merge = true; + for(auto group : groups_to_resolve) { + group->input_vals.erase( + std::remove(group->input_vals.begin(), group->input_vals.end(), forward_val), + group->input_vals.end() + ); + + auto new_edge = segmented_fusion_->newEdge(new_group, group, forward_val); + new_group->consumer_edges.push_back(new_edge); + group->producer_edges.push_back(new_edge); + can_merge = codeGenSupportedMerge(new_group, group); + new_group->consumer_edges.pop_back(); + group->producer_edges.pop_back(); + edges().erase(std::find(edges().begin(), edges().end(), new_edge)); + } + if(!can_merge) { + new_groups.pop_front(); + break; + } + if(can_merge) { + for(auto group : groups_to_resolve) { + if(new_group == nullptr) { + new_group = segmented_fusion_->newGroup(); + new_group->input_vals.push_back(inp); + new_group->output_vals.push_back(forward_val); + new_group->exprs_.insert(new_group->exprs_.begin(), exprs.begin(), exprs.end()); + } + auto new_edge = segmented_fusion_->newEdge(new_group, group, forward_val); + new_group->consumer_edges.push_back(new_edge); + group->producer_edges.push_back(new_edge); + can_merge = codeGenSupportedMerge(new_group, group); + new_group->consumer_edges.pop_back(); + group->producer_edges.pop_back(); + + to_merge_.push_back(new_group); + to_merge_.push_back(group); + mergeNodes(); + new_group = nullptr; + } + } else { + // TODO: handle the case where can_merge is false + } + } + // for(auto group : groups()) { + // for(auto forward_val : group->input_vals) { + // if(!forward_val_to_input_.count(forward_val)) { + // continue; + // } + // Val* inp = forward_val_to_input_[forward_val]; + // group->input_vals.push_back(inp); + // group->input_vals.erase( + // std::remove(group->input_vals.begin(), group->input_vals.end(), forward_val), + // group->input_vals.end() + // ); + // auto exprs = DependencyCheck::getAllExprsBetween( + // {inp}, + // {forward_val} + // ); + // group->exprs_.insert( + // group->exprs_.begin(), + // exprs.begin(), + // exprs.end() + // ); + // } + // } + + // for (Val* forwarded_input : forwarded_fusion_inputs_) { + // std::cout<<"Forwarded input: "<toString()<isFusionInput()) { + // // Nothing to resolve. + // continue; + // } + + // resolveForwardedInput(forwarded_input); + // // aux_group will be removed from segmented_fusion_ by + // // cleanupForwardedInputs. 
+ // } } void SegmentCandidateFinder::findSegments() { @@ -4206,7 +4339,21 @@ void SegmentCandidateFinder::findSegments() { // Resolve all the input expressions needed in each group resolveForwardedInputs(); - + std::cout<<"After resolveForwardedInputs"<input_vals) { + std::cout<<" Input: "<toString()<output_vals) { + std::cout<<" Output: "<toString()<exprs()) { + std::cout<<" Expr: "<toString()<input_vals = VectorOfUniqueEntries(group->input_vals).vector(); } - // NVF_THROW("TMP"); // Do sanity check on the final graph. segmented_fusion_->validate(/*require_disjoint=*/false); @@ -4489,59 +4635,40 @@ void SegmentCandidateFinder::forwardInputs() { // "Terminating" outputs from the excluded input unary exprs, these will be // treated as complete fusion inputs. - VectorOfUniqueEntries forwarded_inputs; { - std::deque to_visit; - for (Val* inp : completeFusion()->inputs()) { - if (UnaryOp* unary_use = shouldForward(inp)) { - to_visit.push_back(unary_use); + for(Val* inp : completeFusion()->inputs()) { + if(!shouldForward(inp)) { + continue; } - } - - while (!to_visit.empty()) { - UnaryOp* uop = to_visit.front(); - to_visit.pop_front(); - - if (UnaryOp* unary_use = shouldForward(uop->out())) { - to_visit.push_back(unary_use); - } else { - // We cannot extend the chain of unary ops, so we finalize this chain by - // saving its output as a forwarded input. - forwarded_inputs.pushBack(uop->out()); + Val* forward_val = inp; + while(shouldForward(forward_val)) { + auto next_val = forward_val->uses().front()->outputs()[0]; + if(next_val->isFusionOutput()) { + break; + } + excluded_inp_unary_exprs_.pushBack(forward_val->uses().front()->as()); + forward_val = next_val; + } + if(!forward_val->isFusionOutput()) { + input_to_forward_val_[inp] = forward_val; + forward_val_to_input_[forward_val] = inp; } - // Either way, `uop` is excluded from merging until - // `resolveForwardedInput` adds it back to one of the segments. 
- excluded_inp_unary_exprs_.pushBack(uop); } } + + for(auto [inp, forward_val] : input_to_forward_val_) { + forwarded_fusion_inputs_.push_back(inp); + forwarded_fusion_inputs_.push_back(forward_val); + auto new_group = segmented_fusion_->newFusionInputGroup(); + input2group_.insert({forward_val, new_group}); + } - auto excluded_fusion_inputs = IterVisitor::getInputsTo( - {forwarded_inputs.begin(), forwarded_inputs.end()}); - - // List of vals to treat as complete fusion inputs for segmentation - forwarded_fusion_inputs_ = completeFusion()->inputs(); - - forwarded_fusion_inputs_.erase( - std::remove_if( - forwarded_fusion_inputs_.begin(), - forwarded_fusion_inputs_.end(), - [&excluded_fusion_inputs](Val* inp) { - return std::find( - excluded_fusion_inputs.begin(), - excluded_fusion_inputs.end(), - inp) != excluded_fusion_inputs.end(); - }), - forwarded_fusion_inputs_.end()); - - forwarded_fusion_inputs_.insert( - forwarded_fusion_inputs_.end(), - forwarded_inputs.begin(), - forwarded_inputs.end()); - - // Insert auxiliary groups to use group dependency on inputs as well - for (auto input : forwarded_fusion_inputs_) { + for(Val* inp : completeFusion()->inputs()) { + if(input_to_forward_val_.count(inp)) { + continue; + } auto new_group = segmented_fusion_->newFusionInputGroup(); - input2group_.insert({input, new_group}); + input2group_.insert({inp, new_group}); } } @@ -4658,7 +4785,21 @@ void SegmentCandidateFinder::resolveForwardedInput(Val* forwarded_input) { NVF_ERROR(to_merge_.empty()); to_merge_.push_back(input_group); to_merge_.push_back(consumer); - mergeNodes(); + auto merged_group = mergeNodes(); + std::cout<<"--------------------------------"<input_vals) { + std::cout<exprs()) { + std::cout<output_vals) { + std::cout< forwarded_fusion_inputs_; + std::unordered_map input_to_forward_val_; + std::unordered_map forward_val_to_input_; //! Keep track of complete fusion input use std::unordered_map input2group_;