Skip to content

Commit

Permalink
[HLO Componentization] Create pass sub-component
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 674577592
  • Loading branch information
sdasgup3 authored and Google-ML-Automation committed Sep 14, 2024
1 parent 5a07f58 commit 4990bb3
Show file tree
Hide file tree
Showing 25 changed files with 703 additions and 493 deletions.
95 changes: 95 additions & 0 deletions xla/hlo/pass/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Description:
# XLA's HLO pass framework: the pass interfaces, the fixed-point pass wrapper,
# and the pass pipeline.

load("@tsl//tsl/platform:rules_cc.bzl", "cc_library")
load(
"@tsl//tsl/platform/default:cuda_build_defs.bzl",
"if_cuda_is_configured",
)
load("//xla:xla.bzl", "xla_cc_test")
load("//xla/tsl:tsl.bzl", "internal_visibility")

# Package-wide defaults: targets default to being visible to the ":friends"
# package group (passed through internal_visibility() from //xla/tsl:tsl.bzl)
# and are declared under the "notice" license type.
package(
# copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
default_visibility = internal_visibility([":friends"]),
licenses = ["notice"],
)

# Package group referenced by this package's default_visibility; it pulls in
# the members of //xla:friends.
package_group(
name = "friends",
includes = [
"//xla:friends",
],
)

# Header-only library (no srcs) exposing the pass interface
# (hlo_pass_interface.h) and the fixed-point pass wrapper (hlo_pass_fix.h).
cc_library(
name = "hlo_pass",
hdrs = [
"hlo_pass_fix.h",
"hlo_pass_interface.h",
],
deps = [
"//xla:status_macros",
"//xla:types",
"//xla/hlo/ir:hlo",
"//xla/hlo/ir:hlo_module_group",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
],
)

# Library built from hlo_pass_pipeline.{h,cc}; depends on ":hlo_pass" for the
# pass interfaces.
cc_library(
name = "hlo_pass_pipeline",
srcs = [
"hlo_pass_pipeline.cc",
],
hdrs = [
"hlo_pass_pipeline.h",
],
# Adds the GOOGLE_CUDA=1 define for this target only, via
# if_cuda_is_configured (presumably only when CUDA is configured -- confirm
# against cuda_build_defs.bzl).
local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
deps = [
":hlo_pass",
"//xla:status_macros",
"//xla:types",
"//xla:util",
"//xla/hlo/ir:hlo",
"//xla/service:compilation_stats",
"//xla/service:dump",
"//xla/service:hlo_graph_dumper",
"//xla/service:hlo_proto_util",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:status",
"@tsl//tsl/profiler/lib:scoped_annotation",
],
)

# Test target built from hlo_pass_pipeline_test.cc, linked against ":hlo_pass"
# and ":hlo_pass_pipeline".
# NOTE(review): deps list both //xla/tests:xla_internal_test_main and
# @com_google_googletest//:gtest_main -- confirm that both are intended.
xla_cc_test(
name = "hlo_pass_pipeline_test",
srcs = ["hlo_pass_pipeline_test.cc"],
deps = [
":hlo_pass",
":hlo_pass_pipeline",
"//xla:test_helpers",
"//xla:util",
"//xla/hlo/ir:hlo",
"//xla/hlo/ir:hlo_module_group",
"//xla/service:hlo_parser",
"//xla/tests:hlo_test_base",
"//xla/tests:xla_internal_test_main",
"//xla/tsl/lib/core:status_test_util",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
"@com_google_googletest//:gtest_main",
"@tsl//tsl/platform:statusor",
],
)
58 changes: 58 additions & 0 deletions xla/hlo/pass/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# XLA HLO Pass Framework

This folder provides the core components of the XLA HLO pass framework, which is
responsible for optimizing and transforming HLO computations within the XLA
compiler.

## Key APIs and Files

### `hlo_pass_interface.h`

Defines the foundational classes for HLO passes:

* `HloPassInterface`: Abstract base class for all HLO passes.
* `HloModulePass`: Subclass for passes that operate on individual HloModules.
* `HloModuleGroupPass`: Subclass for passes that operate on HloModuleGroups
(collections of modules).

Declares core methods such as `Run`, `RunOnModuleGroup`, and
`RunOnChangedComputations`, which passes implement or override to perform their
transformations.

### `hlo_pass_fix.h`

Introduces the `HloPassFix` template class, which runs an HLO pass repeatedly
until a fixed point is reached (no further changes occur in the HLO) or an
iteration limit (25 by default) is hit. It is useful for passes that may trigger
further optimizations when applied iteratively.

### `hlo_pass_pipeline.h`

Defines the `HloPassPipeline` class, which organizes a sequence of HLO passes
into a pipeline for sequential execution. It provides methods to add passes
(`AddPass`) and invariant checkers (`AddInvariantChecker`) to the pipeline. The
`Run` method executes the entire pipeline on an HloModule or HloModuleGroup.

## Example Usage

```C++
// Create a pipeline
HloPassPipeline pipeline("my_pipeline");

// Add passes to the pipeline
pipeline.AddPass<SomeOptimizationPass>(/* pass arguments */);
pipeline.AddPass<HloPassFix<AnotherOptimizationPass>>(/* pass arguments */);

// Run the pipeline on an HloModule
HloModule module(/* ... */);
auto status = pipeline.Run(&module);
```

## Important Considerations

When creating custom HLO passes, inherit from either `HloModulePass` or
`HloModuleGroupPass` depending on the scope of your transformation. Implement
the required virtual methods (e.g., `Run`) to define the pass's behavior.
Utilize `HloPassFix` when your pass's transformations may trigger further
optimizations upon repeated application. Construct an `HloPassPipeline` to
orchestrate the execution of multiple passes in a defined sequence.
136 changes: 136 additions & 0 deletions xla/hlo/pass/hlo_pass_fix.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
/* Copyright 2017 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_HLO_PASS_HLO_PASS_FIX_H_
#define XLA_HLO_PASS_HLO_PASS_FIX_H_

#include <algorithm>
#include <type_traits>

#include "absl/status/statusor.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/ir/hlo_module_group.h"
#include "xla/hlo/pass/hlo_pass_interface.h"
#include "xla/status_macros.h"
#include "xla/types.h"

namespace xla {

// Runs the wrapped HLO pass repeatedly until a fixed point is reached (an
// iteration reports no changes) or `kIterationLimit` iterations have run.
//
// Template parameters:
//   Pass:            the pass to iterate; must derive from HloPassInterface.
//   kIterationLimit: number of iterations after which the loop gives up.
//
// If the limit is reached, the wrapper reports that nothing changed (see the
// individual methods) in case this fixed point is nested inside another one.
template <typename Pass, int kIterationLimit = 25>
class HloPassFix : public Pass {
 public:
  static_assert(std::is_base_of<HloPassInterface, Pass>::value,
                "Pass must be a subclass of HloPassInterface");
  using RunState = HloPassInterface::RunState;

  // Constructs the wrapped pass from `args`. Note that the arguments are
  // handed to `Pass` as lvalues; they are not forwarded with std::forward.
  template <typename... Args>
  explicit HloPassFix(Args&&... args) : Pass(args...) {}

  // Runs the wrapped pass to a fixed point, seeded with the computations that
  // `outer_run_state` recorded as changed in its last iteration. Every
  // computation changed by the inner loop is then added to
  // `outer_run_state->changed_this_iteration`.
  //
  // Returns the first error produced by the wrapped pass, OkStatus otherwise.
  absl::Status RunOnChangedComputations(
      HloModule* module, RunState* outer_run_state,
      const absl::flat_hash_set<absl::string_view>& execution_threads)
      override {
    // Iterate on a private RunState so that the inner loop's iteration counter
    // and change sets are kept separate from the caller's.
    RunState run_state;
    run_state.changed_last_iteration = outer_run_state->changed_last_iteration;
    TF_RETURN_IF_ERROR(RunToFixPoint(module, &run_state, execution_threads));
    outer_run_state->changed_this_iteration.insert(run_state.changed.begin(),
                                                   run_state.changed.end());
    return absl::OkStatus();
  }

  using HloPassInterface::Run;
  // Runs the wrapped pass on `module` to a fixed point.
  //
  // Returns true if the accumulated set of changed computations is non-empty
  // once the loop ends. That set is cleared when the iteration limit is hit,
  // so false is returned in that case. Errors from the wrapped pass are
  // propagated.
  absl::StatusOr<bool> Run(HloModule* module,
                           const absl::flat_hash_set<absl::string_view>&
                               execution_threads) override {
    RunState run_state(module);
    TF_RETURN_IF_ERROR(RunToFixPoint(module, &run_state, execution_threads));
    return !run_state.changed.empty();
  }

  using HloPassInterface::RunOnModuleGroup;
  // Runs the wrapped pass on `module_group` until an iteration reports no
  // change.
  //
  // Returns true if any iteration changed the module group, or false if the
  // iteration limit is reached. Errors from the wrapped pass are propagated.
  absl::StatusOr<bool> RunOnModuleGroup(
      HloModuleGroup* module_group,
      const absl::flat_hash_set<absl::string_view>& execution_threads)
      override {
    bool changed = false;
    bool changed_this_iteration = true;
    int64_t iteration_count = 0;
    VLOG(3) << "Running HloPassFix.";
    while (changed_this_iteration) {
      TF_ASSIGN_OR_RETURN(
          changed_this_iteration,
          Pass::RunOnModuleGroup(module_group, execution_threads));
      changed |= changed_this_iteration;
      VLOG(3) << "changed_this_iteration: " << changed_this_iteration;
      ++iteration_count;
      if (iteration_count == kIterationLimit) {
        VLOG(1) << "Unexpectedly high number of iterations in HLO passes, "
                   "exiting fixed point loop.";
        // Return false in case this fixed point is nested.
        return false;
      }
    }
    return changed;
  }

 private:
  // Repeats single iterations of the wrapped pass for as long as the previous
  // iteration left a non-empty `changed_last_iteration` set, stopping early
  // once `run_state->iteration` reaches `kIterationLimit`.
  absl::Status RunToFixPoint(
      HloModule* module, RunState* run_state,
      const absl::flat_hash_set<absl::string_view>& execution_threads) {
    VLOG(3) << "Running HloPassFix on " << Pass::name();
    while (!run_state->changed_last_iteration.empty()) {
      TF_RETURN_IF_ERROR(
          RunOnChangedComputationsOnce(module, run_state, execution_threads));
      // Report whether *this* iteration changed anything. Changes are recorded
      // in changed_this_iteration; changed_last_iteration is the loop
      // condition and would not describe the iteration that just ran.
      VLOG(3) << Pass::name() << " iteration " << run_state->iteration
              << " changed_this_iteration: "
              << !run_state->changed_this_iteration.empty();
      run_state->IncrementIteration();
      if (run_state->iteration == kIterationLimit) {
        VLOG(1) << "Unexpectedly high number of iterations in HLO passes '"
                << Pass::name() << "' for module '" << module->name()
                << "'. Exiting fixed point loop.";
        // Clear changed and abort in case this fixed point is nested.
        run_state->changed.clear();
        break;
      }
    }
    return absl::OkStatus();
  }

  // Executes one iteration of the wrapped pass and records what it changed in
  // `run_state->changed_this_iteration`.
  absl::Status RunOnChangedComputationsOnce(
      HloModule* module, RunState* run_state,
      const absl::flat_hash_set<absl::string_view>& execution_threads) {
    // If Pass overrides RunOnChangedComputations, just forward to it. The
    // pointer-to-member types differ exactly when Pass (or one of its bases
    // other than HloPassInterface) declares its own RunOnChangedComputations.
    if (!std::is_same<decltype(&HloPassInterface::RunOnChangedComputations),
                      decltype(&Pass::RunOnChangedComputations)>::value) {
      return Pass::RunOnChangedComputations(module, run_state,
                                            execution_threads);
    }
    // If Pass does not override the default
    // HloPassInterface::RunOnChangedComputations that calls into
    // HloPassFix<Pass>::Run, avoid infinite recursion.
    TF_ASSIGN_OR_RETURN(bool changed, Pass::Run(module, execution_threads));
    if (changed) {
      // Pass::Run only reports a module-wide flag, so conservatively mark
      // every computation on the requested execution threads as changed.
      auto computations = module->computations(execution_threads);
      run_state->changed_this_iteration.insert(computations.begin(),
                                               computations.end());
    }
    return absl::OkStatus();
  }
};

} // namespace xla

#endif // XLA_HLO_PASS_HLO_PASS_FIX_H_
Loading

0 comments on commit 4990bb3

Please sign in to comment.