Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support scalar outputs from fusions #3947

Draft
wants to merge 6 commits into
base: polymorphic_outs_step_8
Choose a base branch
from

Conversation

csarofeen
Copy link
Collaborator

No description provided.

@csarofeen csarofeen marked this pull request as draft February 23, 2025 01:59
Copy link

github-actions bot commented Feb 23, 2025

Review updated until commit 6f7b52e

Description

  • Support scalar outputs from fusions in Fusion::addOutputInternal

  • Add logic to simplify constant scalars in SegmentCandidateFinder::segment

  • Introduce initializeExprGroup method in SegmentCandidateFinder

  • Modify buildInitialSegments to handle excluded input unary expressions

  • Update resolveForwardedInputs to handle forwarded inputs more effectively

  • Remove unused resolveScalarsInGroup and removeScalarEdges methods

  • Enhance ExprEvalExecutor to support scalar outputs

  • Update fusion_cache_utils.cpp for better debugging and logging

  • Add isNoOp function in expr_eval_sched.cpp to identify no-op expressions

  • Update registry.cpp to reject non-ExprEval schedulers with scalar outputs

  • Modify tests to accommodate scalar processing and validation changes


Changes walkthrough 📝

Relevant files
Enhancement
8 files
fusion.cpp
Allow non-TensorView outputs in `addOutputInternal`           
+3/-5     
fusion_segmenter.cpp
Add logic to simplify constant scalars and initialize expression
groups
+270/-283
allocations.cpp
Support scalar outputs in `inferOutputSizes`                         
+26/-9   
executor.cpp
Enhance `ExprEvalExecutor` to handle scalar outputs           
+5/-4     
fusion_cache_utils.cpp
Add debugging and logging for fusion groups                           
+25/-0   
expr_eval_sched.cpp
Add `isNoOp` function to identify no-op expressions           
+45/-18 
registry.cpp
Reject non-ExprEval schedulers with scalar outputs             
+12/-0   
fusion_segmenter.h
Update `SegmentCandidateFinder` interface                               
+5/-9     
Tests
3 files
test_gpu2.cpp
Update test to accommodate scalar processing                         
+6/-4     
validator.cpp
Update validation to handle scalar outputs                             
+52/-17 
validator.h
Update function signature for `testValidate`                         
+1/-1     

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Possible Issue

The initializeExprGroup function does not return a value, which seems to be an oversight. It should return the expr_group created or modified.

SegmentGroup* SegmentCandidateFinder::initializeExprGroup(Expr* expr) {
  SegmentedGroup* expr_group = nullptr;
  if(expr2group.count(expr)) {
    expr_group = expr2group.at(expr);
  } else {
    expr_group = segmented_fusion_->newGroup(expr);
    expr2group.insert(std::make_pair(expr, expr_group));
  }

  for (auto inp : expr->inputs()) {
    if (input2group_.count(inp)) {
      expr_group->input_vals.push_back(inp);
      auto aux_group = input2group_.at(inp);
      auto new_edge = segmented_fusion_->newEdge(aux_group, expr_group, inp);
      expr_group->producer_edges.push_back(new_edge);
      aux_group->consumer_edges.push_back(new_edge);
      continue;
    }

    // Could be something like a constant scalar, definition is nullptr, but
    // isn't an "input" to the fusion. At least not one provided by an
    // external source.
    if (inp->definition() == nullptr) {
      continue;
    }

    auto def_group = expr2group.at(inp->definition());
    auto new_edge = segmented_fusion_->newEdge(def_group, expr_group, inp);
    expr_group->producer_edges.push_back(new_edge);
    def_group->consumer_edges.push_back(new_edge);
  }
  for (auto out : expr->outputs()) {
    if (out->isFusionOutput()) {
      expr_group->output_vals.push_back(out);
    }
  }
}
Code Quality

The inferOutputSizes function includes debug print statements (std::cout) that should be removed or replaced with logging mechanisms before merging.

for (Val* output : fusion->outputs()) {
  if (output->isA<TensorView>()) {
    auto output_tv = output->as<TensorView>();
    const auto& [sizes, strides] = inferShapeOfOutput(output_tv, expr_eval);
    const auto dtype = (output_tv->dtype() == DataType::Index)
        ? data_type_to_aten(arg_index_type)
        : data_type_to_aten(output_tv->dtype());
    output_tensor_proxies.pushTensorProxy(sizes, strides, dtype);
  } else if (output->isScalar()) {
    switch (std::get<PrimDataType>(output->dtype().type)) {
      case DataType::Int:
      case DataType::Int32:
        output_tensor_proxies.push(PolymorphicValue(0LL));
        break;
      case DataType::Double:
      case DataType::Float:
        output_tensor_proxies.push(PolymorphicValue(0.0));
        break;
      case DataType::Bool:
        output_tensor_proxies.push(PolymorphicValue(false));
        break;
      default:
        NVF_ERROR("Output type not supported: ", output->toString());
    }
  } else {
    NVF_ERROR("Output type not supported: ", output->toString());
  }
Code Quality

The canSchedule function includes a debug print statement (scheduler_debug_utils::canScheduleMessage) that should be removed or replaced with logging mechanisms before merging.

scheduler_debug_utils::canScheduleMessage(
    "***Rejected*** scheduler ",
    scheduler_type,
    " cannot accept scalar outputs");
return false;

@csarofeen csarofeen changed the base branch from main to polymorphic_outs_step_8 February 23, 2025 14:41
group->producer_edges.push_back(new_edge);
can_merge = codeGenSupportedMerge(new_group, group);
new_group->consumer_edges.pop_back();
group->producer_edges.pop_back();
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove

new_group = nullptr;
}
} else {
// TODO: handle the case where can_merge is false
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Continue

@@ -3969,6 +4018,44 @@ SchedulerRuntimeInfo& SegmentCandidateFinder::runtimeInfo() {
return *runtime_info_;
}

SegmentGroup* SegmentCandidateFinder::initializeExprGroup(Expr* expr) {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove, not helpful.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant