diff --git a/xla/hlo/pass/BUILD b/xla/hlo/pass/BUILD new file mode 100644 index 0000000000000..4932aa3b3a832 --- /dev/null +++ b/xla/hlo/pass/BUILD @@ -0,0 +1,95 @@ +# Description: +# XLA’s HLO Intermediate Representation implementation. + +load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") +load( + "@tsl//tsl/platform/default:cuda_build_defs.bzl", + "if_cuda_is_configured", +) +load("//xla:xla.bzl", "xla_cc_test") +load("//xla/tsl:tsl.bzl", "internal_visibility") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([":friends"]), + licenses = ["notice"], +) + +package_group( + name = "friends", + includes = [ + "//xla:friends", + ], +) + +cc_library( + name = "hlo_pass", + hdrs = [ + "hlo_pass_fix.h", + "hlo_pass_interface.h", + ], + deps = [ + "//xla:status_macros", + "//xla:types", + "//xla/hlo/ir:hlo", + "//xla/hlo/ir:hlo_module_group", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "hlo_pass_pipeline", + srcs = [ + "hlo_pass_pipeline.cc", + ], + hdrs = [ + "hlo_pass_pipeline.h", + ], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), + deps = [ + ":hlo_pass", + "//xla:status_macros", + "//xla:types", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:compilation_stats", + "//xla/service:dump", + "//xla/service:hlo_graph_dumper", + "//xla/service:hlo_proto_util", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:status", + "@tsl//tsl/profiler/lib:scoped_annotation", + ], +) + +xla_cc_test( + name = "hlo_pass_pipeline_test", + srcs = ["hlo_pass_pipeline_test.cc"], + deps = [ + ":hlo_pass", + ":hlo_pass_pipeline", + "//xla:test_helpers", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/ir:hlo_module_group", + "//xla/service:hlo_parser", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + "@tsl//tsl/platform:statusor", + ], +) diff --git a/xla/hlo/pass/README.md b/xla/hlo/pass/README.md new file mode 100644 index 0000000000000..265156ffb9b24 --- /dev/null +++ b/xla/hlo/pass/README.md @@ -0,0 +1,58 @@ +# XLA HLO Pass Framework + +This folder provides the core components of the XLA HLO pass framework, which is +responsible for optimizing and transforming HLO computations within XLA +compiler. + +## Key APIs and Files + +### `hlo_pass_interface.h` + +Defines the foundational classes for HLO passes: + +* `HloPassInterface`: Abstract base class for all HLO passes. +* `HloModulePass`: Subclass for passes that operate on individual HloModules. +* `HloModuleGroupPass`: Subclass for passes that operate on HloModuleGroups +(collections of modules). + +Provides core methods like `Run`, `RunOnModuleGroup`, and +`RunOnChangedComputations` that passes must implement to perform their +transformations. + +### `hlo_pass_fix.h` + +Introduces the `HloPassFix` template class. Allows running an HLO pass +repeatedly until a fixed point is reached (no further changes occur in the HLO). +Useful for passes that may trigger further optimizations when applied +iteratively. + +### `hlo_pass_pipeline.h` + +Defines the `HloPassPipeline` class. Organizes a sequence of HLO passes into a +pipeline for sequential execution. Provides methods to add passes (`AddPass`) +and invariant checkers (`AddInvariantChecker`) to the pipeline. `Run` method +executes the entire pipeline on an HloModule or HloModuleGroup. + +## Example Usage + +```C++ +// Create a pipeline +HloPassPipeline pipeline("my_pipeline"); + +// Add passes to the pipeline +pipeline.AddPass(/* pass arguments */); +pipeline.AddPass>(/* pass arguments */); + +// Run the pipeline on an HloModule +HloModule module(/* ... */); +auto status = pipeline.Run(&module); +``` + +## Important Considerations + +When creating custom HLO passes, inherit from either `HloModulePass` or +`HloModuleGroupPass` depending on the scope of your transformation. Implement +the required virtual methods (e.g., `Run`) to define the pass's behavior. +Utilize `HloPassFix` when your pass's transformations may trigger further +optimizations upon repeated application. Construct `HloPassPipelines` to +orchestrate the execution of multiple passes in a defined sequence. diff --git a/xla/hlo/pass/hlo_pass_fix.h b/xla/hlo/pass/hlo_pass_fix.h new file mode 100644 index 0000000000000..dd92eadcc6d56 --- /dev/null +++ b/xla/hlo/pass/hlo_pass_fix.h @@ -0,0 +1,136 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_PASS_HLO_PASS_FIX_H_ +#define XLA_HLO_PASS_HLO_PASS_FIX_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_module_group.h" +#include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/status_macros.h" +#include "xla/types.h" + +namespace xla { + +// Do an HLO pass to a fix point. +template +class HloPassFix : public Pass { + public: + static_assert(std::is_base_of::value, + "Pass must be a subclass of HloPassInterface"); + using RunState = HloPassInterface::RunState; + template + explicit HloPassFix(Args&&... args) : Pass(args...) {} + + absl::Status RunOnChangedComputations( + HloModule* module, RunState* outer_run_state, + const absl::flat_hash_set& execution_threads) + override { + RunState run_state; + run_state.changed_last_iteration = outer_run_state->changed_last_iteration; + TF_RETURN_IF_ERROR(RunToFixPoint(module, &run_state, execution_threads)); + outer_run_state->changed_this_iteration.insert(run_state.changed.begin(), + run_state.changed.end()); + return absl::OkStatus(); + } + + using HloPassInterface::Run; + absl::StatusOr Run(HloModule* module, + const absl::flat_hash_set& + execution_threads) override { + RunState run_state(module); + TF_RETURN_IF_ERROR(RunToFixPoint(module, &run_state, execution_threads)); + return !run_state.changed.empty(); + } + + using HloPassInterface::RunOnModuleGroup; + absl::StatusOr RunOnModuleGroup( + HloModuleGroup* module_group, + const absl::flat_hash_set& execution_threads) + override { + bool changed = false; + bool changed_this_iteration = true; + int64_t iteration_count = 0; + VLOG(3) << "Running HloPassFix."; + while (changed_this_iteration) { + TF_ASSIGN_OR_RETURN( + changed_this_iteration, + Pass::RunOnModuleGroup(module_group, execution_threads)); + changed |= changed_this_iteration; + VLOG(3) << "changed_this_iteration: " << changed_this_iteration; + ++iteration_count; + if (iteration_count == kIterationLimit) { + VLOG(1) << "Unexpectedly high number of iterations in HLO passes, " + "exiting fixed point loop."; + // Return false in case this is fixed point is nested. + return false; + } + } + return changed; + } + + private: + absl::Status RunToFixPoint( + HloModule* module, RunState* run_state, + const absl::flat_hash_set& execution_threads) { + VLOG(3) << "Running HloPassFix on " << Pass::name(); + while (!run_state->changed_last_iteration.empty()) { + TF_RETURN_IF_ERROR( + RunOnChangedComputationsOnce(module, run_state, execution_threads)); + VLOG(3) << Pass::name() << " iteration " << run_state->iteration + << " changed_this_iteration: " + << !run_state->changed_last_iteration.empty(); + run_state->IncrementIteration(); + if (run_state->iteration == kIterationLimit) { + VLOG(1) << "Unexpectedly high number of iterations in HLO passes '" + << Pass::name() << "' for module '" << module->name() + << "'. Exiting fixed point loop."; + // Clear changed and abort in case this is fixed point is nested. + run_state->changed.clear(); + break; + } + } + return absl::OkStatus(); + } + + absl::Status RunOnChangedComputationsOnce( + HloModule* module, RunState* run_state, + const absl::flat_hash_set& execution_threads) { + // If Pass overrides RunOnChangedComputations, just forward to it. + if (!std::is_same::value) { + return Pass::RunOnChangedComputations(module, run_state, + execution_threads); + } + // If Pass does not override the default + // HloPassInterface::RunOnChangedComputations that calls into + // HloPassFix::Run, avoid infinite recursion. + TF_ASSIGN_OR_RETURN(bool changed, Pass::Run(module, execution_threads)); + if (changed) { + auto computations = module->computations(execution_threads); + run_state->changed_this_iteration.insert(computations.begin(), + computations.end()); + } + return absl::OkStatus(); + } +}; + +} // namespace xla + +#endif // XLA_HLO_PASS_HLO_PASS_FIX_H_ diff --git a/xla/hlo/pass/hlo_pass_interface.h b/xla/hlo/pass/hlo_pass_interface.h new file mode 100644 index 0000000000000..d09f65e271451 --- /dev/null +++ b/xla/hlo/pass/hlo_pass_interface.h @@ -0,0 +1,184 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_PASS_HLO_PASS_INTERFACE_H_ +#define XLA_HLO_PASS_HLO_PASS_INTERFACE_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_module_group.h" +#include "xla/status_macros.h" +#include "xla/types.h" + +namespace xla { + +// Base class for HLO passes. These are used with the HloPassPipeline to +// organize a sequence of passes. An HLO pass should not extend this class +// directly; it should extend HloModulePass or HloModuleGroupPass. +class HloPassInterface { + public: + // Struct that holds states of pass runs across multiple iterations. + struct RunState { + // The current iteration number. + int iteration = 0; + // Set of all changed computations from all pass runs using this state. + absl::flat_hash_set changed; + // Set of changed computation from previous iteration. + absl::flat_hash_set changed_last_iteration; + // Set of changed computation from current iteration. + absl::flat_hash_set changed_this_iteration; + + RunState() = default; + explicit RunState(HloModule* module) + : changed_last_iteration(module->computations().begin(), + module->computations().end()) {} + + // Transition to the next iteration. + // + // Depending on the pass implmentation, one iteration includes all the work + // done between two IncrementIteration calls, there can be arbitrary number + // of passes that ran arbitrary times with this state. + void IncrementIteration() { + using std::swap; + changed.insert(changed_this_iteration.begin(), + changed_this_iteration.end()); + swap(changed_last_iteration, changed_this_iteration); + changed_this_iteration.clear(); + ++iteration; + } + }; + virtual ~HloPassInterface() = default; + virtual absl::string_view name() const = 0; + + // Run the pass on the given HLO module with specified execution_threads. + // Empty execution_threads list means all execution_threads are included. + // Returns whether it modified the module. Note that due to C++ inheritance + // hides overloaded function, Run(HloModule* module) is not a member function + // of a subclass unless it's explicitly brought to the subclass besides + // implementing the virtual version, for instance, + // + // class MyNewPass : public HloModulePass { + // public: + // MyNewPass(); + // absl::string_view name() const override { return "my-new-pass"; } + // + // using HloPassInterface::Run; + // absl::StatusOr Run( + // HloModule* module, + // const absl::flat_hash_set& execution_threads) + // override; + // }; + // + absl::StatusOr Run(HloModule* module) { + return Run(module, /*execution_threads=*/{}); + } + virtual absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) = 0; + + // Run the pass on computation on changed computations from last iteration in + // given HLO module for specified execution_threads, with caller provided + // RunState which holds the state information across multiple iterations. + // + // NOTE: This is a temporary default implementation that conservatively treats + // all computations as changed. Eventually all passes should override this + // method instead of Run() and Run() will call into this method instead. + virtual absl::Status RunOnChangedComputations( + HloModule* module, RunState* run_state, + const absl::flat_hash_set& execution_threads) { + TF_ASSIGN_OR_RETURN(bool changed, Run(module, execution_threads)); + if (changed) { + auto computations = module->computations(execution_threads); + run_state->changed_this_iteration.insert(computations.begin(), + computations.end()); + } + return absl::OkStatus(); + } + + // Run the pass on the given HLO module group for specified + // `execution_threads`. Empty `execution_threads` list means all execution + // threads are included. Returns whether it modified the module group. + // Ideally, the module group variant would be named "Run" as well, but C++ + // does not handle overloaded virtual methods well. + // + // Note that due to C++ inheritance hides overloaded function, + // RunOnModuleGroup(HloModuleGroup* module_group) is not a member function of + // a subclass unless it's explicitly brought to the subclass besides + // implementing the virtual version, for instance, + // + // class MyNewPass : public HloModuleGroupPass { + // public: + // MyNewPass(); + // absl::string_view name() const override { return "my-new-pass"; } + // + // using HloPassInterface::RunOnModuleGroup; + // absl::StatusOr RunOnModuleGroup( + // HloModuleGroup* module_group, + // const absl::flat_hash_set& execution_threads) + // override; + // }; + // + absl::StatusOr RunOnModuleGroup(HloModuleGroup* module_group) { + return RunOnModuleGroup(module_group, /*execution_threads=*/{}); + } + virtual absl::StatusOr RunOnModuleGroup( + HloModuleGroup* module_group, + const absl::flat_hash_set& execution_threads) = 0; + + virtual bool IsPassPipeline() { return false; } +}; + +// Base class for passes which are module-scoped. +class HloModulePass : public HloPassInterface { + public: + // Runs the pass on a module group by iterating through each module in the + // group. + absl::StatusOr RunOnModuleGroup( + HloModuleGroup* module_group, + const absl::flat_hash_set& execution_threads) + override { + bool changed = false; + for (HloModule* module : module_group->modules()) { + TF_ASSIGN_OR_RETURN(bool module_changed, Run(module, execution_threads)); + changed |= module_changed; + } + return changed; + }; + + // Update the layout of a Shape to one that is supported by a given backend. + // One can call this function after modifying the Shape in case that modifying + // the Shape requires changes to the layout for the given Backend. + // + // TODO(b/129084868): Make this Backend dependent instead of requiring + // deriving from the pass and overriding this function. + virtual void UpdateLayout(Shape* shape) {} +}; + +// Base class for passes which are module-group scoped. These passes cannot run +// on an HLO module. +class HloModuleGroupPass : public HloPassInterface { + public: + absl::StatusOr Run(HloModule* module, + const absl::flat_hash_set& + execution_threads) override { + return Internal("Module group pass cannot be run on a module"); + } +}; + +} // namespace xla + +#endif // XLA_HLO_PASS_HLO_PASS_INTERFACE_H_ diff --git a/xla/service/hlo_pass_pipeline.cc b/xla/hlo/pass/hlo_pass_pipeline.cc similarity index 99% rename from xla/service/hlo_pass_pipeline.cc rename to xla/hlo/pass/hlo_pass_pipeline.cc index 82021a8e4cbca..f6b315ca4c961 100644 --- a/xla/service/hlo_pass_pipeline.cc +++ b/xla/hlo/pass/hlo_pass_pipeline.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/hlo_pass_pipeline.h" +#include "xla/hlo/pass/hlo_pass_pipeline.h" #include #include diff --git a/xla/hlo/pass/hlo_pass_pipeline.h b/xla/hlo/pass/hlo_pass_pipeline.h new file mode 100644 index 0000000000000..9510787d0294c --- /dev/null +++ b/xla/hlo/pass/hlo_pass_pipeline.h @@ -0,0 +1,171 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_PASS_HLO_PASS_PIPELINE_H_ +#define XLA_HLO_PASS_HLO_PASS_PIPELINE_H_ + +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/service/compilation_stats.h" +#include "xla/types.h" + +namespace xla { + +class PhaseOrderPipeline; + +// Pipeline of HLO passes. +class HloPassPipeline : public HloPassInterface { + public: + explicit HloPassPipeline(const std::string& name, + CompilationStats* compilation_stats = nullptr) + : name_(name), compilation_stats_(compilation_stats) { + if (compilation_stats == nullptr) { + empty_compilation_stats_ = CompilationStats::MakeNoopStats(); + compilation_stats_ = empty_compilation_stats_.get(); + } + } + absl::string_view name() const override { return name_; } + + // Add a pass to the pipeline. It should be called with the arguments for the + // pass constructor: + // + // pipeline.AddPass(constructor_arg1, constructor_arg2); + // + // Returns a reference to the added pass. + template + T& AddPass(Args&&... args) { + CHECK(!run_called_) << "AddPass cannot be called after Run"; + auto pass = new T(std::forward(args)...); + passes_.push_back(std::unique_ptr(pass)); + return *pass; + } + + // Add an invariant-checking pass to the pipeline. It will be run before and + // after each HLO pass. The invariant checking pass must not mutate the graph + // (it is required to always return "false" from its Run() method). + template + T& AddInvariantChecker(Args&&... args) { + CHECK(!run_called_) << "AddInvariantChecker cannot be called after Run"; + auto pass = new T(std::forward(args)...); + invariant_checkers_.push_back(std::unique_ptr(pass)); + return *pass; + } + + // Add an invariant-checking pass to the pipeline on debug builds only. + template + void AddInvariantCheckerDebug(Args&&... args) { +#ifndef NDEBUG + AddInvariantChecker(std::forward(args)...); +#endif // NDEBUG + } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + using HloPassInterface::RunOnModuleGroup; + absl::StatusOr RunOnModuleGroup( + HloModuleGroup* module_group, + const absl::flat_hash_set& execution_threads) override; + + bool IsPassPipeline() override { return true; } + + // Return size of passes_. + int PassesSize() { return passes_.size(); } + // Return reference to pass specified by index. + HloPassInterface& GetPass(int index) { return *passes_[index]; } + + private: + // Returns the set of passes which are enabled. DebugOptions can selectively + // disable passes via --xla_disable_hlo_passes flag. + std::vector GetEnabledPasses( + const DebugOptions& debug_options); + + // Maybe dumps the given module or module group depending on flag values + // contained in DebugOptions of module config. If it is dumped, saves the + // filenames of the dumps into module metadata. + void MaybeDumpHloAndSaveFilenames(HloModuleGroup& module_group, + absl::string_view after_pass_name, + absl::string_view before_pass_name); + void MaybeDumpHloAndSaveFilenames(HloModule& module, + absl::string_view after_pass_name, + absl::string_view before_pass_name); + + // Runs the invariant checker on the given HLO for specified + // `execution_threads`. Empty `execution_threads` means all execution threads + // are included. HloT can be either HloModule or HloModuleGroup. + template + absl::Status RunInvariantCheckers(HloT* hlo, + absl::string_view after_pass_name) { + return RunInvariantCheckers(hlo, after_pass_name, /*execution_threads=*/{}); + } + template + absl::Status RunInvariantCheckers( + HloT* hlo, absl::string_view after_pass_name, + const absl::flat_hash_set& execution_threads); + + // Helper which runs the given pass on the given HLO. HloT can be either + // HloModule or HloModuleGroup. + template + absl::StatusOr RunPassesInternal( + HloT* hlo, const DebugOptions& debug_options, + const absl::flat_hash_set& execution_threads); + + // Helpers which run the given passes on the given HLO construct. Only + // computations with specified `execution_threads` are considered by the pass, + // empty thread list means all `execution_threads` are considered. These + // helpers enable templating of the core of the pipeline logic by providing + // HloModule and HloModuleGroup specific methods with the same name. + static absl::StatusOr RunHelper( + HloPassInterface* pass, HloModule* module, + const absl::flat_hash_set& execution_threads) { + TF_ASSIGN_OR_RETURN(bool changed, pass->Run(module, execution_threads)); + module->Cleanup(); + return changed; + } + static absl::StatusOr RunHelper( + HloPassInterface* pass, HloModuleGroup* module_group, + const absl::flat_hash_set& execution_threads) { + TF_ASSIGN_OR_RETURN( + bool changed, pass->RunOnModuleGroup(module_group, execution_threads)); + module_group->Cleanup(); + return changed; + } + + const std::string name_; + std::vector> passes_; + std::vector> invariant_checkers_; + bool run_called_ = false; + + CompilationStats* compilation_stats_; + // Default stats instance for when one is not passed in the constructor. + // Use via compilation_stats_, not directly. + std::unique_ptr empty_compilation_stats_; + + // Allow PhaseOrderPipeline to modify private passes_ member in order to + // perform PhaseOrdering. + friend class ::xla::PhaseOrderPipeline; +}; + +} // namespace xla + +#endif // XLA_HLO_PASS_HLO_PASS_PIPELINE_H_ diff --git a/xla/service/hlo_pass_pipeline_test.cc b/xla/hlo/pass/hlo_pass_pipeline_test.cc similarity index 96% rename from xla/service/hlo_pass_pipeline_test.cc rename to xla/hlo/pass/hlo_pass_pipeline_test.cc index 502406bb54d1f..5ef86e33ef946 100644 --- a/xla/service/hlo_pass_pipeline_test.cc +++ b/xla/hlo/pass/hlo_pass_pipeline_test.cc @@ -13,15 +13,32 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/hlo_pass_pipeline.h" - +#include "xla/hlo/pass/hlo_pass_pipeline.h" + +#include +#include +#include +#include +#include + +#include +#include +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_module_group.h" +#include "xla/hlo/pass/hlo_pass_interface.h" #include "xla/service/hlo_parser.h" +#include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { diff --git a/xla/service/BUILD b/xla/service/BUILD index f4bd54ad8b678..ed1ee8b007974 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -5481,6 +5481,8 @@ xla_cc_test( ], ) +# Deprecated, use //third_party/tensorflow/compiler/xla/hlo/pass:hlo_pass # +# instead. cc_library( name = "hlo_pass", hdrs = [ @@ -5492,54 +5494,26 @@ cc_library( "//xla:types", "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_module_group", + "//xla/hlo/pass:hlo_pass", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", ], ) +# Deprecated, use +# //third_party/tensorflow/compiler/xla/hlo/pass:hlo_pass_pipeline instead. cc_library( name = "hlo_pass_pipeline", - srcs = [ - "hlo_pass_pipeline.cc", - ], - hdrs = [ - "hlo_pass_pipeline.h", - ], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), + hdrs = ["hlo_pass_pipeline.h"], deps = [ ":compilation_stats", - ":dump", - ":hlo_graph_dumper", ":hlo_pass", - ":hlo_proto_util", - "//xla:status_macros", "//xla:types", - "//xla:util", "//xla/hlo/ir:hlo", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", + "//xla/hlo/pass:hlo_pass_pipeline", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:status", - "@tsl//tsl/profiler/lib:scoped_annotation", - ], -) - -xla_cc_test( - name = "hlo_pass_pipeline_test", - srcs = ["hlo_pass_pipeline_test.cc"], - deps = [ - ":hlo_parser", - ":hlo_pass_pipeline", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "//xla/tsl/lib/core:status_test_util", ], ) diff --git a/xla/service/cpu/BUILD b/xla/service/cpu/BUILD index da794174752a7..d4ee1eb952f51 100644 --- a/xla/service/cpu/BUILD +++ b/xla/service/cpu/BUILD @@ -242,6 +242,8 @@ cc_library( "//xla/backends/cpu/runtime:thunk", "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_module_group", + "//xla/hlo/pass:hlo_pass", + "//xla/hlo/pass:hlo_pass_pipeline", "//xla/mlir_hlo", "//xla/mlir_hlo:all_passes", "//xla/mlir_hlo:mhlo_passes", @@ -287,8 +289,6 @@ cc_library( "//xla/service:hlo_memory_scheduler", "//xla/service:hlo_module_config", "//xla/service:hlo_ordering", - "//xla/service:hlo_pass", - "//xla/service:hlo_pass_pipeline", "//xla/service:hlo_profile_printer_data_cc", "//xla/service:hlo_proto_cc", "//xla/service:hlo_proto_util", @@ -1548,7 +1548,7 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "//xla/service:hlo_pass", + "//xla/hlo/pass:hlo_pass", "@tsl//tsl/platform:errors", ], ) @@ -1601,8 +1601,8 @@ cc_library( ":target_machine_features", "//xla:util", "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", "//xla/service:hlo_cost_analysis", - "//xla/service:hlo_pass", "//xla/service/llvm_ir:dynamic_update_slice_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -1926,9 +1926,9 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/evaluator:hlo_evaluator", "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", "//xla/service:hlo_cost_analysis", "//xla/service:hlo_creation_utils", - "//xla/service:hlo_pass", "//xla/service:pattern_matcher", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status:statusor", @@ -1956,8 +1956,8 @@ cc_library( "//xla:status_macros", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", "//xla/service:hlo_creation_utils", - "//xla/service:hlo_pass", "//xla/service:pattern_matcher", "@com_google_absl//absl/algorithm:container", "@tsl//tsl/platform:platform_port", diff --git a/xla/service/cpu/conv_canonicalization.h b/xla/service/cpu/conv_canonicalization.h index 5bdecaf7b3ec7..a65e22781d13e 100644 --- a/xla/service/cpu/conv_canonicalization.h +++ b/xla/service/cpu/conv_canonicalization.h @@ -17,8 +17,8 @@ limitations under the License. #define XLA_SERVICE_CPU_CONV_CANONICALIZATION_H_ #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" #include "xla/service/cpu/target_machine_features.h" -#include "xla/service/hlo_pass_interface.h" namespace xla { namespace cpu { diff --git a/xla/service/cpu/cpu_compiler.cc b/xla/service/cpu/cpu_compiler.cc index 80f55bae864e4..ada4b7ba8dfae 100644 --- a/xla/service/cpu/cpu_compiler.cc +++ b/xla/service/cpu/cpu_compiler.cc @@ -87,6 +87,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_module_group.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" +#include "xla/hlo/pass/hlo_pass_fix.h" +#include "xla/hlo/pass/hlo_pass_pipeline.h" #include "xla/literal.h" #include "xla/map_util.h" #include "xla/mlir_hlo/transforms/passes.h" @@ -148,8 +150,6 @@ limitations under the License. #include "xla/service/hlo_memory_scheduler.h" #include "xla/service/hlo_module_config.h" #include "xla/service/hlo_ordering.h" -#include "xla/service/hlo_pass_fix.h" -#include "xla/service/hlo_pass_pipeline.h" #include "xla/service/hlo_profile_printer_data.pb.h" #include "xla/service/hlo_verifier.h" #include "xla/service/indexed_array_analysis.h" diff --git a/xla/service/cpu/onednn_contraction_rewriter.h b/xla/service/cpu/onednn_contraction_rewriter.h index 7864dae961386..503d8a8ee2563 100644 --- a/xla/service/cpu/onednn_contraction_rewriter.h +++ b/xla/service/cpu/onednn_contraction_rewriter.h @@ -23,7 +23,7 @@ limitations under the License. #include "unsupported/Eigen/CXX11/Tensor" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/service/hlo_pass_interface.h" +#include "xla/hlo/pass/hlo_pass_interface.h" #include "tsl/platform/threadpool.h" namespace xla { diff --git a/xla/service/cpu/onednn_ops_rewriter.h b/xla/service/cpu/onednn_ops_rewriter.h index 8e777d8889ad8..e56e013540b34 100644 --- a/xla/service/cpu/onednn_ops_rewriter.h +++ b/xla/service/cpu/onednn_ops_rewriter.h @@ -20,7 +20,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/service/hlo_pass_interface.h" +#include "xla/hlo/pass/hlo_pass_interface.h" namespace xla { namespace cpu { diff --git a/xla/service/cpu/onednn_rewriter.h b/xla/service/cpu/onednn_rewriter.h index 53fd5c0f977c4..95d2692ec6aa1 100644 --- a/xla/service/cpu/onednn_rewriter.h +++ b/xla/service/cpu/onednn_rewriter.h @@ -22,7 +22,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/service/hlo_pass_interface.h" +#include "xla/hlo/pass/hlo_pass_interface.h" namespace xla { namespace cpu { diff --git a/xla/service/cpu/parallel_task_assignment.h b/xla/service/cpu/parallel_task_assignment.h index e523323262b79..bee7bd6983a0d 100644 --- a/xla/service/cpu/parallel_task_assignment.h +++ b/xla/service/cpu/parallel_task_assignment.h @@ -26,9 +26,9 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" #include "xla/service/cpu/target_machine_features.h" #include "xla/service/hlo_cost_analysis.h" -#include "xla/service/hlo_pass_interface.h" #include "xla/util.h" namespace xla { diff --git a/xla/service/gpu/autotuning/BUILD b/xla/service/gpu/autotuning/BUILD index 0f287c6ec3095..cae9e277097f9 100644 --- a/xla/service/gpu/autotuning/BUILD +++ b/xla/service/gpu/autotuning/BUILD @@ -44,6 +44,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", "//xla/hlo/utils:hlo_query", "//xla/pjrt/distributed:key_value_store_interface", "//xla/service:algorithm_util", @@ -53,7 +54,6 @@ cc_library( "//xla/service:hlo_cost_analysis", "//xla/service:hlo_graph_dumper", "//xla/service:hlo_module_config", - "//xla/service:hlo_pass", "//xla/service:shaped_buffer", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:buffer_comparator", @@ -122,12 +122,12 @@ xla_test( "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass_pipeline", "//xla/pjrt/distributed:key_value_store_interface", "//xla/service:call_inliner", "//xla/service:dump", "//xla/service:executable", "//xla/service:hlo_module_config", - "//xla/service:hlo_pass_pipeline", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", "//xla/service/gpu:backend_configs_cc", @@ -175,8 +175,8 @@ cc_library( "//xla:util", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", "//xla/service:hlo_module_config", - "//xla/service:hlo_pass", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:buffer_comparator", "//xla/service/gpu:cublas_cudnn", @@ -351,9 +351,9 @@ cc_library( "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", "//xla/service:executable", "//xla/service:hlo_module_config", - "//xla/service:hlo_pass", "//xla/service:slow_operation_alarm", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:cublas_cudnn", @@ -446,8 +446,8 @@ cc_library( "//xla:util", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", "//xla/service:executable", - "//xla/service:hlo_pass", "//xla/service:shaped_buffer", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:ir_emission_utils", @@ -482,7 +482,7 @@ xla_test( ":custom_kernel_fusion_autotuner", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", - "//xla/service:hlo_pass_pipeline", + "//xla/hlo/pass:hlo_pass_pipeline", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # fixdeps: keep "@com_google_googletest//:gtest", diff --git a/xla/service/gpu/autotuning/conv_algorithm_picker.h b/xla/service/gpu/autotuning/conv_algorithm_picker.h index 881b1535a7ed0..3a35c2d6044f4 100644 --- a/xla/service/gpu/autotuning/conv_algorithm_picker.h +++ b/xla/service/gpu/autotuning/conv_algorithm_picker.h @@ -30,12 +30,12 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" #include "xla/service/gpu/autotuning/autotuner_compile_util.h" #include "xla/service/gpu/autotuning/autotuner_util.h" #include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/gpu/gpu_conv_runner.h" #include "xla/service/hlo_module_config.h" -#include "xla/service/hlo_pass_interface.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/stream_executor.h" diff --git a/xla/service/gpu/autotuning/custom_kernel_fusion_autotuner.h b/xla/service/gpu/autotuning/custom_kernel_fusion_autotuner.h index 07aad07aebd9b..7ea7f4b51beea 100644 --- a/xla/service/gpu/autotuning/custom_kernel_fusion_autotuner.h +++ b/xla/service/gpu/autotuning/custom_kernel_fusion_autotuner.h @@ -21,8 +21,8 @@ limitations under the License. #include "xla/autotuning.pb.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" #include "xla/service/gpu/autotuning/autotuner_util.h" -#include "xla/service/hlo_pass_interface.h" #include "xla/xla.pb.h" namespace xla { diff --git a/xla/service/gpu/autotuning/custom_kernel_fusion_autotuner_test.cc b/xla/service/gpu/autotuning/custom_kernel_fusion_autotuner_test.cc index 03a332345a638..3a214ffeb56d5 100644 --- a/xla/service/gpu/autotuning/custom_kernel_fusion_autotuner_test.cc +++ b/xla/service/gpu/autotuning/custom_kernel_fusion_autotuner_test.cc @@ -21,8 +21,8 @@ limitations under the License. #include #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_pipeline.h" #include "xla/service/gpu/autotuning/autotuner_util.h" -#include "xla/service/hlo_pass_pipeline.h" #include "xla/tests/hlo_test_base.h" #include "xla/xla.pb.h" #include "tsl/platform/test.h" diff --git a/xla/service/gpu/autotuning/gemm_algorithm_picker.h b/xla/service/gpu/autotuning/gemm_algorithm_picker.h index 237358388b16e..3ba3bde9430df 100644 --- a/xla/service/gpu/autotuning/gemm_algorithm_picker.h +++ b/xla/service/gpu/autotuning/gemm_algorithm_picker.h @@ -26,9 +26,9 @@ limitations under the License. #include "xla/autotune_results.pb.h" #include "xla/autotuning.pb.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" #include "xla/service/gpu/autotuning/autotuner_util.h" #include "xla/service/hlo_module_config.h" -#include "xla/service/hlo_pass_interface.h" #include "xla/shape.h" #include "xla/stream_executor/blas.h" #include "xla/stream_executor/device_memory.h" diff --git a/xla/service/gpu/autotuning/gemm_fusion_autotuner.h b/xla/service/gpu/autotuning/gemm_fusion_autotuner.h index 7d852b57b1386..b49eefb0fabd8 100644 --- a/xla/service/gpu/autotuning/gemm_fusion_autotuner.h +++ b/xla/service/gpu/autotuning/gemm_fusion_autotuner.h @@ -34,12 +34,12 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" #include "xla/pjrt/distributed/key_value_store_interface.h" #include "xla/service/executable.h" #include "xla/service/gpu/autotuning/autotuner_compile_util.h" #include "xla/service/gpu/autotuning/autotuner_util.h" #include "xla/service/gpu/matmul_utils.h" -#include "xla/service/hlo_pass_interface.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/semantic_version.h" #include "xla/xla.pb.h" diff --git a/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc b/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc index 29f2c2133e9b4..4a57c241f6336 100644 --- a/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc +++ b/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc @@ -33,6 +33,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/pass/hlo_pass_pipeline.h" #include "xla/pjrt/distributed/key_value_store_interface.h" #include "xla/service/call_inliner.h" #include "xla/service/dump.h" @@ -45,7 +46,6 @@ limitations under the License. #include "xla/service/gpu/transforms/gemm_fusion.h" #include "xla/service/gpu/transforms/gemm_rewriter.h" #include "xla/service/hlo_module_config.h" -#include "xla/service/hlo_pass_pipeline.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/stream_executor/device_description.h" diff --git a/xla/service/hlo_pass_fix.h b/xla/service/hlo_pass_fix.h index cedd856e99b97..c7dab4303b6e1 100644 --- a/xla/service/hlo_pass_fix.h +++ b/xla/service/hlo_pass_fix.h @@ -16,121 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_HLO_PASS_FIX_H_ #define XLA_SERVICE_HLO_PASS_FIX_H_ -#include -#include - -#include "absl/status/statusor.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/ir/hlo_module_group.h" -#include "xla/service/hlo_pass_interface.h" -#include "xla/status_macros.h" -#include "xla/types.h" - -namespace xla { - -// Do an HLO pass to a fix point. -template -class HloPassFix : public Pass { - public: - static_assert(std::is_base_of::value, - "Pass must be a subclass of HloPassInterface"); - using RunState = HloPassInterface::RunState; - template - explicit HloPassFix(Args&&... args) : Pass(args...) {} - - absl::Status RunOnChangedComputations( - HloModule* module, RunState* outer_run_state, - const absl::flat_hash_set& execution_threads) - override { - RunState run_state; - run_state.changed_last_iteration = outer_run_state->changed_last_iteration; - TF_RETURN_IF_ERROR(RunToFixPoint(module, &run_state, execution_threads)); - outer_run_state->changed_this_iteration.insert(run_state.changed.begin(), - run_state.changed.end()); - return absl::OkStatus(); - } - - using HloPassInterface::Run; - absl::StatusOr Run(HloModule* module, - const absl::flat_hash_set& - execution_threads) override { - RunState run_state(module); - TF_RETURN_IF_ERROR(RunToFixPoint(module, &run_state, execution_threads)); - return !run_state.changed.empty(); - } - - using HloPassInterface::RunOnModuleGroup; - absl::StatusOr RunOnModuleGroup( - HloModuleGroup* module_group, - const absl::flat_hash_set& execution_threads) - override { - bool changed = false; - bool changed_this_iteration = true; - int64_t iteration_count = 0; - VLOG(3) << "Running HloPassFix."; - while (changed_this_iteration) { - TF_ASSIGN_OR_RETURN( - changed_this_iteration, - Pass::RunOnModuleGroup(module_group, execution_threads)); - changed |= changed_this_iteration; - VLOG(3) << "changed_this_iteration: " << changed_this_iteration; - ++iteration_count; - if (iteration_count == kIterationLimit) { - VLOG(1) << "Unexpectedly high number of iterations in HLO passes, " - "exiting fixed point loop."; - // Return false in case this is fixed point is nested. - return false; - } - } - return changed; - } - - private: - absl::Status RunToFixPoint( - HloModule* module, RunState* run_state, - const absl::flat_hash_set& execution_threads) { - VLOG(3) << "Running HloPassFix on " << Pass::name(); - while (!run_state->changed_last_iteration.empty()) { - TF_RETURN_IF_ERROR( - RunOnChangedComputationsOnce(module, run_state, execution_threads)); - VLOG(3) << Pass::name() << " iteration " << run_state->iteration - << " changed_this_iteration: " - << !run_state->changed_last_iteration.empty(); - run_state->IncrementIteration(); - if (run_state->iteration == kIterationLimit) { - VLOG(1) << "Unexpectedly high number of iterations in HLO passes '" - << Pass::name() << "' for module '" << module->name() - << "'. Exiting fixed point loop."; - // Clear changed and abort in case this is fixed point is nested. - run_state->changed.clear(); - break; - } - } - return absl::OkStatus(); - } - - absl::Status RunOnChangedComputationsOnce( - HloModule* module, RunState* run_state, - const absl::flat_hash_set& execution_threads) { - // If Pass overrides RunOnChangedComputations, just forward to it. - if (!std::is_same::value) { - return Pass::RunOnChangedComputations(module, run_state, - execution_threads); - } - // If Pass does not override the default - // HloPassInterface::RunOnChangedComputations that calls into - // HloPassFix::Run, avoid infinite recursion. - TF_ASSIGN_OR_RETURN(bool changed, Pass::Run(module, execution_threads)); - if (changed) { - auto computations = module->computations(execution_threads); - run_state->changed_this_iteration.insert(computations.begin(), - computations.end()); - } - return absl::OkStatus(); - } -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/pass/hlo_pass_fix.h" #endif // XLA_SERVICE_HLO_PASS_FIX_H_ diff --git a/xla/service/hlo_pass_interface.h b/xla/service/hlo_pass_interface.h index fa233bcc0c262..1b6a373b3a178 100644 --- a/xla/service/hlo_pass_interface.h +++ b/xla/service/hlo_pass_interface.h @@ -16,169 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_HLO_PASS_INTERFACE_H_ #define XLA_SERVICE_HLO_PASS_INTERFACE_H_ -#include "absl/container/flat_hash_set.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/ir/hlo_module_group.h" -#include "xla/status_macros.h" -#include "xla/types.h" - -namespace xla { - -// Base class for HLO passes. These are used with the HloPassPipeline to -// organize a sequence of passes. An HLO pass should not extend this class -// directly; it should extend HloModulePass or HloModuleGroupPass. -class HloPassInterface { - public: - // Struct that holds states of pass runs across multiple iterations. - struct RunState { - // The current iteration number. - int iteration = 0; - // Set of all changed computations from all pass runs using this state. - absl::flat_hash_set changed; - // Set of changed computation from previous iteration. - absl::flat_hash_set changed_last_iteration; - // Set of changed computation from current iteration. - absl::flat_hash_set changed_this_iteration; - - RunState() = default; - explicit RunState(HloModule* module) - : changed_last_iteration(module->computations().begin(), - module->computations().end()) {} - - // Transition to the next iteration. - // - // Depending on the pass implmentation, one iteration includes all the work - // done between two IncrementIteration calls, there can be arbitrary number - // of passes that ran arbitrary times with this state. - void IncrementIteration() { - using std::swap; - changed.insert(changed_this_iteration.begin(), - changed_this_iteration.end()); - swap(changed_last_iteration, changed_this_iteration); - changed_this_iteration.clear(); - ++iteration; - } - }; - virtual ~HloPassInterface() = default; - virtual absl::string_view name() const = 0; - - // Run the pass on the given HLO module with specified execution_threads. - // Empty execution_threads list means all execution_threads are included. - // Returns whether it modified the module. Note that due to C++ inheritance - // hides overloaded function, Run(HloModule* module) is not a member function - // of a subclass unless it's explicitly brought to the subclass besides - // implementing the virtual version, for instance, - // - // class MyNewPass : public HloModulePass { - // public: - // MyNewPass(); - // absl::string_view name() const override { return "my-new-pass"; } - // - // using HloPassInterface::Run; - // absl::StatusOr Run( - // HloModule* module, - // const absl::flat_hash_set& execution_threads) - // override; - // }; - // - absl::StatusOr Run(HloModule* module) { - return Run(module, /*execution_threads=*/{}); - } - virtual absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) = 0; - - // Run the pass on computation on changed computations from last iteration in - // given HLO module for specified execution_threads, with caller provided - // RunState which holds the state information across multiple iterations. - // - // NOTE: This is a temporary default implementation that conservatively treats - // all computations as changed. Eventually all passes should override this - // method instead of Run() and Run() will call into this method instead. - virtual absl::Status RunOnChangedComputations( - HloModule* module, RunState* run_state, - const absl::flat_hash_set& execution_threads) { - TF_ASSIGN_OR_RETURN(bool changed, Run(module, execution_threads)); - if (changed) { - auto computations = module->computations(execution_threads); - run_state->changed_this_iteration.insert(computations.begin(), - computations.end()); - } - return absl::OkStatus(); - } - - // Run the pass on the given HLO module group for specified - // `execution_threads`. Empty `execution_threads` list means all execution - // threads are included. Returns whether it modified the module group. - // Ideally, the module group variant would be named "Run" as well, but C++ - // does not handle overloaded virtual methods well. - // - // Note that due to C++ inheritance hides overloaded function, - // RunOnModuleGroup(HloModuleGroup* module_group) is not a member function of - // a subclass unless it's explicitly brought to the subclass besides - // implementing the virtual version, for instance, - // - // class MyNewPass : public HloModuleGroupPass { - // public: - // MyNewPass(); - // absl::string_view name() const override { return "my-new-pass"; } - // - // using HloPassInterface::RunOnModuleGroup; - // absl::StatusOr RunOnModuleGroup( - // HloModuleGroup* module_group, - // const absl::flat_hash_set& execution_threads) - // override; - // }; - // - absl::StatusOr RunOnModuleGroup(HloModuleGroup* module_group) { - return RunOnModuleGroup(module_group, /*execution_threads=*/{}); - } - virtual absl::StatusOr RunOnModuleGroup( - HloModuleGroup* module_group, - const absl::flat_hash_set& execution_threads) = 0; - - virtual bool IsPassPipeline() { return false; } -}; - -// Base class for passes which are module-scoped. -class HloModulePass : public HloPassInterface { - public: - // Runs the pass on a module group by iterating through each module in the - // group. - absl::StatusOr RunOnModuleGroup( - HloModuleGroup* module_group, - const absl::flat_hash_set& execution_threads) - override { - bool changed = false; - for (HloModule* module : module_group->modules()) { - TF_ASSIGN_OR_RETURN(bool module_changed, Run(module, execution_threads)); - changed |= module_changed; - } - return changed; - }; - - // Update the layout of a Shape to one that is supported by a given backend. - // One can call this function after modifying the Shape in case that modifying - // the Shape requires changes to the layout for the given Backend. - // - // TODO(b/129084868): Make this Backend dependent instead of requiring - // deriving from the pass and overriding this function. - virtual void UpdateLayout(Shape* shape) {} -}; - -// Base class for passes which are module-group scoped. These passes cannot run -// on an HLO module. -class HloModuleGroupPass : public HloPassInterface { - public: - absl::StatusOr Run(HloModule* module, - const absl::flat_hash_set& - execution_threads) override { - return Internal("Module group pass cannot be run on a module"); - } -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/pass/hlo_pass_interface.h" #endif // XLA_SERVICE_HLO_PASS_INTERFACE_H_ diff --git a/xla/service/hlo_pass_pipeline.h b/xla/service/hlo_pass_pipeline.h index 4c07de1ac0caf..83d693ccfef3f 100644 --- a/xla/service/hlo_pass_pipeline.h +++ b/xla/service/hlo_pass_pipeline.h @@ -16,156 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_HLO_PASS_PIPELINE_H_ #define XLA_SERVICE_HLO_PASS_PIPELINE_H_ -#include -#include -#include -#include - -#include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/service/compilation_stats.h" -#include "xla/service/hlo_pass_interface.h" -#include "xla/types.h" - -namespace xla { - -class PhaseOrderPipeline; - -// Pipeline of HLO passes. -class HloPassPipeline : public HloPassInterface { - public: - explicit HloPassPipeline(const std::string& name, - CompilationStats* compilation_stats = nullptr) - : name_(name), compilation_stats_(compilation_stats) { - if (compilation_stats == nullptr) { - empty_compilation_stats_ = CompilationStats::MakeNoopStats(); - compilation_stats_ = empty_compilation_stats_.get(); - } - } - absl::string_view name() const override { return name_; } - - // Add a pass to the pipeline. It should be called with the arguments for the - // pass constructor: - // - // pipeline.AddPass(constructor_arg1, constructor_arg2); - // - // Returns a reference to the added pass. - template - T& AddPass(Args&&... args) { - CHECK(!run_called_) << "AddPass cannot be called after Run"; - auto pass = new T(std::forward(args)...); - passes_.push_back(std::unique_ptr(pass)); - return *pass; - } - - // Add an invariant-checking pass to the pipeline. It will be run before and - // after each HLO pass. The invariant checking pass must not mutate the graph - // (it is required to always return "false" from its Run() method). - template - T& AddInvariantChecker(Args&&... args) { - CHECK(!run_called_) << "AddInvariantChecker cannot be called after Run"; - auto pass = new T(std::forward(args)...); - invariant_checkers_.push_back(std::unique_ptr(pass)); - return *pass; - } - - // Add an invariant-checking pass to the pipeline on debug builds only. - template - void AddInvariantCheckerDebug(Args&&... args) { -#ifndef NDEBUG - AddInvariantChecker(std::forward(args)...); -#endif // NDEBUG - } - - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; - using HloPassInterface::RunOnModuleGroup; - absl::StatusOr RunOnModuleGroup( - HloModuleGroup* module_group, - const absl::flat_hash_set& execution_threads) override; - - bool IsPassPipeline() override { return true; } - - // Return size of passes_. - int PassesSize() { return passes_.size(); } - // Return reference to pass specified by index. - HloPassInterface& GetPass(int index) { return *passes_[index]; } - - private: - // Returns the set of passes which are enabled. DebugOptions can selectively - // disable passes via --xla_disable_hlo_passes flag. - std::vector GetEnabledPasses( - const DebugOptions& debug_options); - - // Maybe dumps the given module or module group depending on flag values - // contained in DebugOptions of module config. If it is dumped, saves the - // filenames of the dumps into module metadata. - void MaybeDumpHloAndSaveFilenames(HloModuleGroup& module_group, - absl::string_view after_pass_name, - absl::string_view before_pass_name); - void MaybeDumpHloAndSaveFilenames(HloModule& module, - absl::string_view after_pass_name, - absl::string_view before_pass_name); - - // Runs the invariant checker on the given HLO for specified - // `execution_threads`. Empty `execution_threads` means all execution threads - // are included. HloT can be either HloModule or HloModuleGroup. - template - absl::Status RunInvariantCheckers(HloT* hlo, - absl::string_view after_pass_name) { - return RunInvariantCheckers(hlo, after_pass_name, /*execution_threads=*/{}); - } - template - absl::Status RunInvariantCheckers( - HloT* hlo, absl::string_view after_pass_name, - const absl::flat_hash_set& execution_threads); - - // Helper which runs the given pass on the given HLO. HloT can be either - // HloModule or HloModuleGroup. - template - absl::StatusOr RunPassesInternal( - HloT* hlo, const DebugOptions& debug_options, - const absl::flat_hash_set& execution_threads); - - // Helpers which run the given passes on the given HLO construct. Only - // computations with specified `execution_threads` are considered by the pass, - // empty thread list means all `execution_threads` are considered. These - // helpers enable templating of the core of the pipeline logic by providing - // HloModule and HloModuleGroup specific methods with the same name. - static absl::StatusOr RunHelper( - HloPassInterface* pass, HloModule* module, - const absl::flat_hash_set& execution_threads) { - TF_ASSIGN_OR_RETURN(bool changed, pass->Run(module, execution_threads)); - module->Cleanup(); - return changed; - } - static absl::StatusOr RunHelper( - HloPassInterface* pass, HloModuleGroup* module_group, - const absl::flat_hash_set& execution_threads) { - TF_ASSIGN_OR_RETURN( - bool changed, pass->RunOnModuleGroup(module_group, execution_threads)); - module_group->Cleanup(); - return changed; - } - - const std::string name_; - std::vector> passes_; - std::vector> invariant_checkers_; - bool run_called_ = false; - - CompilationStats* compilation_stats_; - // Default stats instance for when one is not passed in the constructor. - // Use via compilation_stats_, not directly. - std::unique_ptr empty_compilation_stats_; - - // Allow PhaseOrderPipeline to modify private passes_ member in order to - // perform PhaseOrdering. - friend class ::xla::PhaseOrderPipeline; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/pass/hlo_pass_pipeline.h" #endif // XLA_SERVICE_HLO_PASS_PIPELINE_H_