
[XLA] Fix typos in comments and clean up includes in HloRematerialization.

PiperOrigin-RevId: 681573006
Google-ML-Automation committed Oct 2, 2024
1 parent 13e486d commit 6947dee
Showing 3 changed files with 44 additions and 24 deletions.
xla/service/BUILD (3 additions, 0 deletions)
@@ -5510,7 +5510,10 @@ cc_library(
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:numbers",
"@tsl//tsl/platform:statusor",
],
)

xla/service/hlo_rematerialization.cc (31 additions, 16 deletions)
@@ -16,6 +16,7 @@ limitations under the License.
#include "xla/service/hlo_rematerialization.h"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <limits>
@@ -39,6 +40,8 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_clone_context.h"
#include "xla/hlo/ir/hlo_computation.h"
@@ -48,16 +51,21 @@ limitations under the License.
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/ir/hlo_schedule.h"
#include "xla/hlo/utils/hlo_query.h"
#include "xla/layout_util.h"
#include "xla/map_util.h"
#include "xla/service/call_graph.h"
#include "xla/service/hlo_cost_analysis.h"
#include "xla/service/hlo_dataflow_analysis.h"
#include "xla/service/hlo_dce.h"
#include "xla/service/logical_buffer.h"
#include "xla/service/tuple_points_to_analysis.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/status_macros.h"
#include "xla/util.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/numbers.h"
#include "tsl/platform/statusor.h"

namespace xla {

@@ -66,8 +74,7 @@ namespace {
using ::tsl::strings::HumanReadableNumBytes;

// Potential optimizations:
// . TODO(b/35244891): Avoid N^2 behavior by keeping a priority queue
// of candidates.
// . Avoid N^2 behavior by keeping a priority queue of candidates.
// . Cache IsRematerializable in Item? Only correct if control
// predecessors and successors don't change.

@@ -660,7 +667,7 @@ class MemoryUsageTracker {

const HloRematerialization::Options& options() const { return options_; }

// Check invariants of the data structure. This is expensive to call.
// Checks invariants of the data structure. This is expensive to call.
bool Check() const;

std::string ToString() const;
@@ -710,19 +717,19 @@ class MemoryUsageTracker {
}
};

// Adjust our tracked memory usage as a result of this new item coming into
// Adjusts our tracked memory usage as a result of this new item coming into
// scope.
void CountAllocatedMemory(Item* item);

// Adjust our tracked memory usage as a result of this item going out of
// Adjusts our tracked memory usage as a result of this item going out of
// scope.
absl::Status CountFreedMemory(Item* item);

// Buffers have users and users have buffers used, this function resolves
// Buffers have users and users have buffers used. This function resolves
// outstanding issues in that bidirectional dependency.
void ReplaceUsesInUsersOfBuffer(Buffer& buffer, BufferId old_id) const;

// Get the compact shape of given hlo instruction. An internal cache is used
// Gets the compact shape of given hlo instruction. An internal cache is used
// to avoid computing the shape multiple times.
absl::StatusOr<const Shape*> GetCompactShape(const HloInstruction* hlo);

@@ -739,7 +746,7 @@ class MemoryUsageTracker {
std::move(users), live_out, has_indirect_uses);
}

// Create a new buffer representing a rematerialization of given buffer for
// Creates a new buffer representing a rematerialization of given buffer for
// the given uses.
Buffer& RematerializeBuffer(const Buffer& original_buffer, Item* remat_item,
UsesList&& rematerialized_uses) {
@@ -755,10 +762,10 @@ class MemoryUsageTracker {
/*has_indirect_uses=*/false);
}

// Return number of bytes allocated for the buffer with the given id. Buffers
// allocated by the calling computation (eg, parameter and output buffers) are
// considered to have zero bytes because the memory is accounted for in a
// different computation.
// Returns the number of bytes allocated for the buffer with the given id.
// Buffers allocated by the calling computation (eg, parameter and output
// buffers) are considered to have zero bytes because the memory is accounted
// for in a different computation.
int64_t AllocatedSize(BufferId buffer_id) const {
const Buffer& buffer = buffers_.at(buffer_id);
HloInstruction* inst = buffer.defining_instruction->instruction;
@@ -776,8 +783,8 @@ }
}
}

// Returns true if BeginInstruction and EndInstruction has been called for the
// given instruction.
// Returns whether BeginInstruction and EndInstruction have been called for
// the given instruction.
bool IsFinished(Item* item) const {
return item->placed && item != in_progress_item_;
}
@@ -815,7 +822,7 @@ class MemoryUsageTracker {
return false;
}

// Create a new buffer, add it to buffers_, and return a reference.
// Creates a new buffer, adds it to buffers_, and returns a reference.
Buffer& NewBuffer(Item* defining_instruction, const Shape& shape,
const ShapeIndex& index, UsesList&& uses, bool live_out,
bool has_indirect_uses) {
@@ -1893,7 +1900,7 @@ MemoryUsageTracker::PickRematerializationCandidates(
continue;
}

// First, calculate the cost of compression rematerialziation for this
// First, calculate the cost of compression rematerialization for this
// instruction.
if (options_.remat_mode_config.compress && block.size() == 1) {
auto cost =
@@ -1989,6 +1996,8 @@ UsesList MemoryUsageTracker::GetItemUses(Item* item) const {
return combined_users;
}

// Performs the rematerialization of all items in `best_items` and returns the
// number of net instructions added.
absl::StatusOr<int64_t> RematerializeInstructions(
MemoryUsageTracker* memory_tracker, std::vector<Item*>* best_items,
absl::flat_hash_set<const HloInstruction*>* remat_move_instructions,
@@ -2174,6 +2183,8 @@ absl::StatusOr<int64_t> RematerializeInstructions(
return net_instructions_added;
}

// Performs rematerialization of `best_item` via the compression strategy.
// Returns the net number of instructions added.
absl::StatusOr<int64_t> CompressInstruction(MemoryUsageTracker* memory_tracker,
Item* best_item,
const Shape& compact_shape,
@@ -2224,9 +2235,12 @@ absl::StatusOr<int64_t> CompressInstruction(MemoryUsageTracker* memory_tracker,
instruction_list->InsertBeforeInstructions(uncompressed_item, place_before);
instruction_list->InsertAfterInstructions(compressed_item, {best_item});

// Net two instructions added.
return 2;
}

// Performs rematerialization of `best_item` via the host offload strategy.
// Returns the net number of instructions added.
absl::StatusOr<int64_t> OffloadInstruction(MemoryUsageTracker* memory_tracker,
Item* best_item,
InstructionList* instruction_list) {
@@ -2486,6 +2500,7 @@ absl::StatusOr<int64_t> OffloadInstruction(MemoryUsageTracker* memory_tracker,
best_item, copy_start_to_host_item, copy_done_to_host_item,
copy_start_to_device_item, copy_done_to_device_item));

// Net four instructions added.
return 4;
}

xla/service/hlo_rematerialization.h (10 additions, 8 deletions)
@@ -15,6 +15,9 @@
#ifndef XLA_SERVICE_HLO_REMATERIALIZATION_H_
#define XLA_SERVICE_HLO_REMATERIALIZATION_H_

#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <utility>

@@ -62,8 +65,8 @@ class HloRematerialization : public HloModulePass {
: recompute(recompute),
compress(compress),
host_offload(host_offload) {}
bool recompute; // Enables the kCompress RematStrategy.
bool compress; // Enables the kRecompute RematStrategy.
bool recompute; // Enables the kRecompute RematStrategy.
bool compress; // Enables the kCompress RematStrategy.
bool host_offload; // Enables the kHostOffload RematStrategy.
};

@@ -180,9 +183,9 @@ class HloRematerialization : public HloModulePass {

protected:
// Rematerializes instructions within the given computation. 'schedule'
// constains the order in which the computation's instructions will be emitted
// in the backend. Rematerialized instructions will be added to the HLO
// computation and inserted into 'schedule'.
// constrains the order in which the computation's instructions will be
// emitted in the backend. Rematerialized instructions will be added to the
// HLO computation and inserted into 'schedule'.
virtual absl::StatusOr<bool> RematerializeComputation(
HloComputation* computation, HloSchedule* schedule,
int64_t memory_limit_bytes, int64_t min_remat_size,
@@ -212,9 +215,8 @@ class HloRematerialization : public HloModulePass {
std::unique_ptr<CallGraph> call_graph_;

// The peak memory usage of each computation. The map contains only those
// computations called from sequential context
// (CallContext::kSequential). These values are updated as rematerialization
// occurs.
// computations called from sequential context (CallContext::kSequential).
// These values are updated as rematerialization occurs.
absl::flat_hash_map<const HloComputation*, int64_t> computation_peak_memory_;

std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_;
