From 6947deebd96d7b3da5d3cd2fb4ec41acd5e61abe Mon Sep 17 00:00:00 2001 From: xla authors Date: Wed, 2 Oct 2024 13:32:02 -0700 Subject: [PATCH] [XLA] Fix typos in comments and clean up includes in HloRematerialization. PiperOrigin-RevId: 681573006 --- xla/service/BUILD | 3 ++ xla/service/hlo_rematerialization.cc | 47 ++++++++++++++++++---------- xla/service/hlo_rematerialization.h | 18 ++++++----- 3 files changed, 44 insertions(+), 24 deletions(-) diff --git a/xla/service/BUILD b/xla/service/BUILD index e72fff7038d06..d1926ad651224 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -5510,7 +5510,10 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:numbers", + "@tsl//tsl/platform:statusor", ], ) diff --git a/xla/service/hlo_rematerialization.cc b/xla/service/hlo_rematerialization.cc index d7090605bbc83..9afba2b246f9a 100644 --- a/xla/service/hlo_rematerialization.cc +++ b/xla/service/hlo_rematerialization.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/service/hlo_rematerialization.h" #include +#include #include #include #include @@ -39,6 +40,8 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_clone_context.h" #include "xla/hlo/ir/hlo_computation.h" @@ -48,16 +51,21 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/utils/hlo_query.h" +#include "xla/layout_util.h" #include "xla/map_util.h" +#include "xla/service/call_graph.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_dataflow_analysis.h" #include "xla/service/hlo_dce.h" #include "xla/service/logical_buffer.h" +#include "xla/service/tuple_points_to_analysis.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/util.h" #include "tsl/platform/errors.h" +#include "tsl/platform/numbers.h" +#include "tsl/platform/statusor.h" namespace xla { @@ -66,8 +74,7 @@ namespace { using ::tsl::strings::HumanReadableNumBytes; // Potential optimizations: -// . TODO(b/35244891): Avoid N^2 behavior by keeping a priority queue -// of candidates. +// . Avoid N^2 behavior by keeping a priority queue of candidates. // . Cache IsRematerializable in Item? Only correct if control // predecessors and successors don't change. @@ -660,7 +667,7 @@ class MemoryUsageTracker { const HloRematerialization::Options& options() const { return options_; } - // Check invariants of the data structure. This is expensive to call. + // Checks invariants of the data structure. This is expensive to call. bool Check() const; std::string ToString() const; @@ -710,19 +717,19 @@ class MemoryUsageTracker { } }; - // Adjust our tracked memory usage as a result of this new item coming into + // Adjusts our tracked memory usage as a result of this new item coming into // scope. void CountAllocatedMemory(Item* item); - // Adjust our tracked memory usage as a result of this item going out of + // Adjusts our tracked memory usage as a result of this item going out of // scope. absl::Status CountFreedMemory(Item* item); - // Buffers have users and users have buffers used, this function resolves + // Buffers have users and users have buffers used. This function resolves // outstanding issues in that bidirectional dependency. void ReplaceUsesInUsersOfBuffer(Buffer& buffer, BufferId old_id) const; - // Get the compact shape of given hlo instruction. An internal cache is used + // Gets the compact shape of given hlo instruction. An internal cache is used // to avoid computing the shape multiple times. absl::StatusOr GetCompactShape(const HloInstruction* hlo); @@ -739,7 +746,7 @@ class MemoryUsageTracker { std::move(users), live_out, has_indirect_uses); } - // Create a new buffer representing a rematerialization of given buffer for + // Creates a new buffer representing a rematerialization of given buffer for // the given uses. Buffer& RematerializeBuffer(const Buffer& original_buffer, Item* remat_item, UsesList&& rematerialized_uses) { @@ -755,10 +762,10 @@ class MemoryUsageTracker { /*has_indirect_uses=*/false); } - // Return number of bytes allocated for the buffer with the given id. Buffers - // allocated by the calling computation (eg, parameter and output buffers) are - // considered to have zero bytes because the memory is accounted for in a - // different computation. + // Returns the number of bytes allocated for the buffer with the given id. + // Buffers allocated by the calling computation (eg, parameter and output + // buffers) are considered to have zero bytes because the memory is accounted + // for in a different computation. int64_t AllocatedSize(BufferId buffer_id) const { const Buffer& buffer = buffers_.at(buffer_id); HloInstruction* inst = buffer.defining_instruction->instruction; @@ -776,8 +783,8 @@ class MemoryUsageTracker { } } - // Returns true if BeginInstruction and EndInstruction has been called for the - // given instruction. + // Returns whether BeginInstruction and EndInstruction have been called for + // the given instruction. bool IsFinished(Item* item) const { return item->placed && item != in_progress_item_; } @@ -815,7 +822,7 @@ class MemoryUsageTracker { return false; } - // Create a new buffer, add it to buffers_, and return a reference. + // Creates a new buffer, adds it to buffers_, and returns a reference. Buffer& NewBuffer(Item* defining_instruction, const Shape& shape, const ShapeIndex& index, UsesList&& uses, bool live_out, bool has_indirect_uses) { @@ -1893,7 +1900,7 @@ MemoryUsageTracker::PickRematerializationCandidates( continue; } - // First, calculate the cost of compression rematerialziation for this + // First, calculate the cost of compression rematerialization for this // instruction. if (options_.remat_mode_config.compress && block.size() == 1) { auto cost = @@ -1989,6 +1996,8 @@ UsesList MemoryUsageTracker::GetItemUses(Item* item) const { return combined_users; } +// Performs the rematerialization of all items in `best_items` and returns the +// number of net instructions added. absl::StatusOr RematerializeInstructions( MemoryUsageTracker* memory_tracker, std::vector* best_items, absl::flat_hash_set* remat_move_instructions, @@ -2174,6 +2183,8 @@ absl::StatusOr RematerializeInstructions( return net_instructions_added; } +// Performs rematerialization of `best_item` via the compression strategy. +// Returns the net number of instructions added. absl::StatusOr CompressInstruction(MemoryUsageTracker* memory_tracker, Item* best_item, const Shape& compact_shape, @@ -2224,9 +2235,12 @@ absl::StatusOr CompressInstruction(MemoryUsageTracker* memory_tracker, instruction_list->InsertBeforeInstructions(uncompressed_item, place_before); instruction_list->InsertAfterInstructions(compressed_item, {best_item}); + // Net two instructions added. return 2; } +// Performs rematerialization of `best_item` via the host offload strategy. +// Returns the net number of instructions added. absl::StatusOr OffloadInstruction(MemoryUsageTracker* memory_tracker, Item* best_item, InstructionList* instruction_list) { @@ -2486,6 +2500,7 @@ absl::StatusOr OffloadInstruction(MemoryUsageTracker* memory_tracker, best_item, copy_start_to_host_item, copy_done_to_host_item, copy_start_to_device_item, copy_done_to_device_item)); + // Net four instructions added. return 4; } diff --git a/xla/service/hlo_rematerialization.h b/xla/service/hlo_rematerialization.h index 4eba8f1a2bdc8..74960976aa1d6 100644 --- a/xla/service/hlo_rematerialization.h +++ b/xla/service/hlo_rematerialization.h @@ -15,6 +15,9 @@ #ifndef XLA_SERVICE_HLO_REMATERIALIZATION_H_ #define XLA_SERVICE_HLO_REMATERIALIZATION_H_ +#include +#include +#include #include #include @@ -62,8 +65,8 @@ class HloRematerialization : public HloModulePass { : recompute(recompute), compress(compress), host_offload(host_offload) {} - bool recompute; // Enables the kCompress RematStrategy. - bool compress; // Enables the kRecompute RematStrategy. + bool recompute; // Enables the kRecompute RematStrategy. + bool compress; // Enables the kCompress RematStrategy. bool host_offload; // Enables the kHostOffload RematStrategy. }; @@ -180,9 +183,9 @@ class HloRematerialization : public HloModulePass { protected: // Rematerializes instructions within the given computation. 'schedule' - // constains the order in which the computation's instructions will be emitted - // in the backend. Rematerialized instructions will be added to the HLO - // computation and inserted into 'schedule'. + // constrains the order in which the computation's instructions will be + // emitted in the backend. Rematerialized instructions will be added to the + // HLO computation and inserted into 'schedule'. virtual absl::StatusOr RematerializeComputation( HloComputation* computation, HloSchedule* schedule, int64_t memory_limit_bytes, int64_t min_remat_size, @@ -212,9 +215,8 @@ class HloRematerialization : public HloModulePass { std::unique_ptr call_graph_; // The peak memory usage of each computation. The map contains only those - // computations called from sequential context - // (CallContext::kSequential). These values are updated as rematerialization - // occurs. + // computations called from sequential context (CallContext::kSequential). + // These values are updated as rematerialization occurs. absl::flat_hash_map computation_peak_memory_; std::unique_ptr points_to_analysis_;