Skip to content

Commit

Permalink
[XLA:TPU] Add LoopOptimizerBestFitHeap class that models alternate memory for memory bound loops and accounts for fragmentation.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 675017830
  • Loading branch information
subhankarshah authored and Google-ML-Automation committed Sep 16, 2024
1 parent af733ec commit 30bf2dd
Show file tree
Hide file tree
Showing 4 changed files with 588 additions and 0 deletions.
2 changes: 2 additions & 0 deletions xla/service/memory_space_assignment/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,8 @@ cc_library(
"//xla/service:hlo_buffer",
"//xla/service:hlo_proto_cc",
"//xla/service:hlo_value",
"//xla/service/heap_simulator",
"//xla/service/heap_simulator:allocation_block",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
Expand Down
253 changes: 253 additions & 0 deletions xla/service/memory_space_assignment/memory_bound_loop_optimizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ limitations under the License.
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/utils/hlo_live_range.h"
#include "xla/service/buffer_value.h"
#include "xla/service/heap_simulator/allocation_block.h"
#include "xla/service/heap_simulator/heap_simulator.h"
#include "xla/service/hlo.pb.h"
#include "xla/service/hlo_alias_analysis.h"
#include "xla/service/hlo_buffer.h"
Expand Down Expand Up @@ -71,6 +73,257 @@ std::optional<int64_t> GetInstructionIndex(

} // namespace

// Registers a BufferInterval for `allocation_block` in buffer_intervals_,
// keyed by the block's address. The interval takes the block's size,
// inclusive start time and end time, starts with an empty colocation list, and
// has its last field set to true only when `colocated_with` is null.
//
// When `colocated_with` is non-null, `allocation_block` is additionally
// appended to the colocations of the interval registered for `colocated_with`.
void LoopOptimizerBestFitHeap::CreateBufferInterval(
    const AllocationBlock& allocation_block,
    const AllocationBlock* colocated_with) {
  const bool has_no_parent = (colocated_with == nullptr);
  buffer_intervals_[&allocation_block] =
      BufferInterval({&allocation_block,
                      allocation_block.size,
                      allocation_block.inclusive_start_time,
                      allocation_block.end_time,
                      {},
                      has_no_parent});
  if (has_no_parent) {
    return;
  }
  // Record this block on the interval it is colocated with.
  buffer_intervals_[colocated_with].colocations.push_back(&allocation_block);
}

std::optional<HeapSimulator::Chunk>
LoopOptimizerBestFitHeap::MaybeFindChunkCandidate(
const AllocationBlock& allocation_block, int64_t preferred_offset) {
Chunk chunk_candidate = FindChunkCandidate(
buffer_intervals_[&allocation_block], preferred_offset);
if (chunk_candidate.chunk_end() <= size_limit_per_heap_) {
return chunk_candidate;
}
return std::nullopt;
}

std::optional<HeapSimulator::Chunk>
LoopOptimizerBestFitHeap::FindAndCommitChunkCandidate(
const AllocationBlock& allocation_block, int64_t preferred_offset) {
std::optional<Chunk> chunk =
MaybeFindChunkCandidate(allocation_block, preferred_offset);
if (chunk.has_value()) {
CommitChunk(buffer_intervals_[&allocation_block], chunk.value());
}
return chunk;
}

// Removes `chunk` from interval_tree_ for the given start and end times.
// The removal result is wrapped in CHECK, so a falsy result from Remove() is
// treated as a fatal error rather than being reported to the caller.
void LoopOptimizerBestFitHeap::RemoveChunk(int64_t start_time, int64_t end_time,
Chunk chunk) {
CHECK(interval_tree_.Remove(start_time, end_time, chunk));
}

void LoopOptimizerBestFitHeap::RemoveEvenChunks(
int64_t begin_idx_in_loop, int64_t end_idx_in_loop,
std::optional<HeapSimulator::Chunk>& chunk) {
RemoveChunk(begin_idx_in_loop, end_idx_in_loop, chunk.value());
RemoveChunk(begin_idx_in_loop + 2 * loop_size_,
end_idx_in_loop + 2 * loop_size_, chunk.value());
}

void LoopOptimizerBestFitHeap::RemoveOddChunks(
int64_t begin_idx_in_loop, int64_t end_idx_in_loop,
std::optional<HeapSimulator::Chunk>& chunk) {
RemoveChunk(begin_idx_in_loop + loop_size_, end_idx_in_loop + loop_size_,
chunk.value());
RemoveChunk(begin_idx_in_loop + 3 * loop_size_,
end_idx_in_loop + 3 * loop_size_, chunk.value());
}

// Removes both halves of an even/odd chunk pair from the interval tree.
//
// The interval is first validated and, if it begins at a negative index,
// shifted forward by one loop length, mirroring what the allocation entry
// points in this file do. Both members of `chunks` must hold a value, since
// the Remove*Chunks helpers unwrap them unconditionally.
void LoopOptimizerBestFitHeap::RemoveEvenOddChunkPair(
    int64_t begin_idx_in_loop, int64_t end_idx_in_loop,
    EvenOddChunkPair& chunks) {
  CheckAllocationIntervalValid(begin_idx_in_loop, end_idx_in_loop);
  ShiftAllocationIntervalIfRequired(begin_idx_in_loop, end_idx_in_loop);
  // Work on local copies of the two optionals; `chunks` itself is not
  // modified.
  auto [even_part, odd_part] = chunks;
  RemoveEvenChunks(begin_idx_in_loop, end_idx_in_loop, even_part);
  RemoveOddChunks(begin_idx_in_loop, end_idx_in_loop, odd_part);
}

// Appends a new AllocationBlock to allocation_blocks_ and returns a reference
// to it. The block is initialized from `start_time`, `end_time` and `size`,
// followed by two -1 placeholders and the block's position in
// allocation_blocks_ (the container size before the append).
//
// NOTE(review): callers in this file keep the returned reference alive across
// further GetAllocationBlock calls, which presumably relies on
// allocation_blocks_ being a container with stable references — confirm
// against the class declaration.
const AllocationBlock& LoopOptimizerBestFitHeap::GetAllocationBlock(
    int64_t start_time, int64_t end_time, int64_t size) {
  const int64_t unset = static_cast<int64_t>(-1);
  const int64_t block_index = static_cast<int64_t>(allocation_blocks_.size());
  allocation_blocks_.push_back(
      {start_time, end_time, size, unset, unset, block_index});
  return allocation_blocks_.back();
}

// Creates the two allocation blocks used for even iterations: one at
// [begin_idx_in_loop, end_idx_in_loop] and one shifted forward by
// 2 * loop_size_. The shifted block's buffer interval is registered as
// colocated with the first block.
//
// Returns the first (unshifted) block.
const AllocationBlock& LoopOptimizerBestFitHeap::CreateEvenAllocationBlock(
    int64_t begin_idx_in_loop, int64_t end_idx_in_loop, int64_t size) {
  const AllocationBlock& primary =
      GetAllocationBlock(begin_idx_in_loop, end_idx_in_loop, size);
  CreateBufferInterval(primary);

  const int64_t two_loops = 2 * loop_size_;
  const AllocationBlock& repeat = GetAllocationBlock(
      begin_idx_in_loop + two_loops, end_idx_in_loop + two_loops, size);
  CreateBufferInterval(repeat, &primary);
  return primary;
}

// Creates the two allocation blocks used for odd iterations: the interval
// [begin_idx_in_loop, end_idx_in_loop] shifted forward by loop_size_ and by
// 3 * loop_size_. The later block's buffer interval is registered as
// colocated with the earlier one.
//
// Returns the earlier block (the one shifted by loop_size_).
const AllocationBlock& LoopOptimizerBestFitHeap::CreateOddAllocationBlock(
    int64_t begin_idx_in_loop, int64_t end_idx_in_loop, int64_t size) {
  const AllocationBlock& primary = GetAllocationBlock(
      begin_idx_in_loop + loop_size_, end_idx_in_loop + loop_size_, size);
  CreateBufferInterval(primary);

  const int64_t three_loops = 3 * loop_size_;
  const AllocationBlock& repeat = GetAllocationBlock(
      begin_idx_in_loop + three_loops, end_idx_in_loop + three_loops, size);
  CreateBufferInterval(repeat, &primary);
  return primary;
}

// CHECK-fails unless [begin_idx_in_loop, end_idx_in_loop] is an interval this
// class accepts. The conditions enforced are:
//   * begin_idx_in_loop <= end_idx_in_loop,
//   * -loop_size_ <= begin_idx_in_loop < loop_size_,
//   * 0 <= end_idx_in_loop < 2 * loop_size_,
//   * the inclusive length (end - begin + 1) is at most 2 * loop_size_.
void LoopOptimizerBestFitHeap::CheckAllocationIntervalValid(
int64_t begin_idx_in_loop, int64_t end_idx_in_loop) const {
// The interval must not be inverted.
CHECK_LE(begin_idx_in_loop, end_idx_in_loop);
// The start may reach back at most one loop length before index 0 ...
CHECK_LE(-1 * loop_size_, begin_idx_in_loop);
// ... and must begin before the end of the first loop length.
CHECK_LT(begin_idx_in_loop, loop_size_);
// The end must be non-negative and fall within the first two loop lengths.
CHECK_LE(0, end_idx_in_loop);
CHECK_LT(end_idx_in_loop, 2 * loop_size_);
// The inclusive span may cover at most two loop lengths.
CHECK_LE(end_idx_in_loop - begin_idx_in_loop + 1, 2 * loop_size_);
}

// If the interval begins at a negative index, moves both endpoints forward by
// loop_size_ (in place, through the reference parameters). Intervals that
// already begin at a non-negative index are left untouched.
void LoopOptimizerBestFitHeap::ShiftAllocationIntervalIfRequired(
    int64_t& begin_idx_in_loop, int64_t& end_idx_in_loop) const {
  if (begin_idx_in_loop >= 0) {
    return;
  }
  end_idx_in_loop += loop_size_;
  begin_idx_in_loop += loop_size_;
}

// Searches for a pair of chunks (one for even iterations, one for odd
// iterations) of `size` covering [begin_idx_in_loop, end_idx_in_loop], without
// leaving either chunk committed.
//
// `preferred_offsets` carries the preferred offset for the even chunk in
// `.first` and for the odd chunk in `.second`.
//
// Returns {even_chunk, odd_chunk} when both were found, and
// {std::nullopt, std::nullopt} otherwise.
EvenOddChunkPair LoopOptimizerBestFitHeap::FindEvenAndOddAllocationBetween(
    int64_t begin_idx_in_loop, int64_t end_idx_in_loop, int64_t size,
    std::pair<int64_t, int64_t> preferred_offsets) {
  CheckAllocationIntervalValid(begin_idx_in_loop, end_idx_in_loop);
  ShiftAllocationIntervalIfRequired(begin_idx_in_loop, end_idx_in_loop);

  const AllocationBlock& even_block =
      CreateEvenAllocationBlock(begin_idx_in_loop, end_idx_in_loop, size);
  const AllocationBlock& odd_block =
      CreateOddAllocationBlock(begin_idx_in_loop, end_idx_in_loop, size);

  // The even chunk is committed temporarily: even and odd chunks may overlap
  // in time, so the odd search has to see the even placement.
  std::optional<HeapSimulator::Chunk> even_chunk =
      FindAndCommitChunkCandidate(even_block, preferred_offsets.first);
  if (!even_chunk.has_value()) {
    return {std::nullopt, std::nullopt};
  }

  std::optional<HeapSimulator::Chunk> odd_chunk =
      MaybeFindChunkCandidate(odd_block, preferred_offsets.second);

  // This is a query only, so undo the temporary even commit regardless of
  // whether the odd search succeeded.
  RemoveEvenChunks(begin_idx_in_loop, end_idx_in_loop, even_chunk);

  if (!odd_chunk.has_value()) {
    return {std::nullopt, std::nullopt};
  }
  return {even_chunk, odd_chunk};
}

// Finds and commits a pair of chunks (one for even iterations, one for odd
// iterations) of `size` covering [begin_idx_in_loop, end_idx_in_loop].
//
// `preferred_offsets` carries the preferred offset for the even chunk in
// `.first` and for the odd chunk in `.second`.
//
// Returns {even_chunk, odd_chunk} with both committed on success. If either
// chunk cannot be placed, nothing stays committed and
// {std::nullopt, std::nullopt} is returned.
EvenOddChunkPair LoopOptimizerBestFitHeap::AllocateEvenAndOddBetween(
    int64_t begin_idx_in_loop, int64_t end_idx_in_loop, int64_t size,
    std::pair<int64_t, int64_t> preferred_offsets) {
  CheckAllocationIntervalValid(begin_idx_in_loop, end_idx_in_loop);
  ShiftAllocationIntervalIfRequired(begin_idx_in_loop, end_idx_in_loop);

  const AllocationBlock& even_block =
      CreateEvenAllocationBlock(begin_idx_in_loop, end_idx_in_loop, size);
  const AllocationBlock& odd_block =
      CreateOddAllocationBlock(begin_idx_in_loop, end_idx_in_loop, size);

  // Commit the even chunk first: even and odd chunks may overlap in time, so
  // the odd search has to see the even placement.
  std::optional<HeapSimulator::Chunk> even_chunk =
      FindAndCommitChunkCandidate(even_block, preferred_offsets.first);
  if (!even_chunk.has_value()) {
    return {std::nullopt, std::nullopt};
  }

  std::optional<HeapSimulator::Chunk> odd_chunk =
      FindAndCommitChunkCandidate(odd_block, preferred_offsets.second);
  if (!odd_chunk.has_value()) {
    // Roll back the even commit so a failed allocation leaves no trace.
    RemoveEvenChunks(begin_idx_in_loop, end_idx_in_loop, even_chunk);
    return {std::nullopt, std::nullopt};
  }
  return {even_chunk, odd_chunk};
}

// Creates four allocation blocks for [begin_idx_in_loop, end_idx_in_loop]: the
// interval itself plus copies shifted forward by 1, 2 and 3 loop lengths. The
// three shifted blocks have their buffer intervals registered as colocated
// with the unshifted block.
//
// Returns the unshifted block.
const AllocationBlock&
LoopOptimizerBestFitHeap::CreateSameEvenAndOddAllocationBlock(
    int64_t begin_idx_in_loop, int64_t end_idx_in_loop, int64_t size) {
  const AllocationBlock& base_block =
      GetAllocationBlock(begin_idx_in_loop, end_idx_in_loop, size);
  CreateBufferInterval(base_block);

  for (int64_t iteration = 1; iteration <= 3; ++iteration) {
    const int64_t shift = iteration * loop_size_;
    const AllocationBlock& shifted_block = GetAllocationBlock(
        begin_idx_in_loop + shift, end_idx_in_loop + shift, size);
    CreateBufferInterval(shifted_block, &base_block);
  }
  return base_block;
}

// Searches, without committing, for a single chunk of `size` that can be used
// in both even and odd iterations over [begin_idx_in_loop, end_idx_in_loop].
//
// Returns a pair whose two members are the same chunk, or both std::nullopt if
// no candidate fits.
EvenOddChunkPair LoopOptimizerBestFitHeap::FindSameEvenAndOddAllocationBetween(
    int64_t begin_idx_in_loop, int64_t end_idx_in_loop, int64_t size,
    int64_t preferred_offset) {
  CheckAllocationIntervalValid(begin_idx_in_loop, end_idx_in_loop);
  ShiftAllocationIntervalIfRequired(begin_idx_in_loop, end_idx_in_loop);
  // An allocation shared by even and odd iterations cannot be double
  // buffered, i.e. it must span at most one loop iteration.
  CHECK_LE(end_idx_in_loop - begin_idx_in_loop + 1, loop_size_);
  std::optional<HeapSimulator::Chunk> shared_chunk = MaybeFindChunkCandidate(
      CreateSameEvenAndOddAllocationBlock(begin_idx_in_loop, end_idx_in_loop,
                                          size),
      preferred_offset);
  return {shared_chunk, shared_chunk};
}

// Finds and, if found, commits a single chunk of `size` that is used in both
// even and odd iterations over [begin_idx_in_loop, end_idx_in_loop].
//
// Returns a pair whose two members are the same chunk, or both std::nullopt if
// no candidate fits (in which case nothing is committed).
EvenOddChunkPair LoopOptimizerBestFitHeap::AllocateSameEvenAndOddBetween(
    int64_t begin_idx_in_loop, int64_t end_idx_in_loop, int64_t size,
    int64_t preferred_offset) {
  CheckAllocationIntervalValid(begin_idx_in_loop, end_idx_in_loop);
  ShiftAllocationIntervalIfRequired(begin_idx_in_loop, end_idx_in_loop);
  // An allocation shared by even and odd iterations cannot be double
  // buffered, i.e. it must span at most one loop iteration.
  CHECK_LE(end_idx_in_loop - begin_idx_in_loop + 1, loop_size_);
  std::optional<HeapSimulator::Chunk> shared_chunk =
      FindAndCommitChunkCandidate(
          CreateSameEvenAndOddAllocationBlock(begin_idx_in_loop,
                                              end_idx_in_loop, size),
          preferred_offset);
  return {shared_chunk, shared_chunk};
}

// Renders the interval tree's contents for iterations `begin_iteration`
// through `end_iteration` (inclusive) as a string, by delegating to
// NodesOverlappingInTimeToAsciiArt with loop_size_ as its third argument.
//
// Requires 0 <= begin_iteration <= end_iteration.
std::string LoopOptimizerBestFitHeap::MemoryUsageToAsciiArt(
    int64_t begin_iteration, int64_t end_iteration) const {
  CHECK_LE(0, begin_iteration);
  CHECK_LE(begin_iteration, end_iteration);
  // Convert the iteration range to an inclusive time range.
  const int64_t first_time = loop_size_ * begin_iteration;
  const int64_t last_time = loop_size_ * (end_iteration + 1) - 1;
  return interval_tree_.NodesOverlappingInTimeToAsciiArt(first_time, last_time,
                                                         loop_size_);
}

std::vector<int64_t> LoopOptimizerBestFitHeap::RemainingMemoryByTime() const {
// Only 2nd and 3rd iterations have the correct (and identical) memory usage.
// 1st and 4th iterations serve only to model the boundary conditions.
std::vector<int64_t> memory_used_by_time =
interval_tree_.MemoryUsedInInterval(loop_size_ * 2, loop_size_ * 3 - 1);
std::vector<int64_t> remaining_memory_by_time(loop_size_);
for (int i = 0; i < loop_size_; ++i) {
remaining_memory_by_time[i] = size_limit_per_heap_ - memory_used_by_time[i];
}
return remaining_memory_by_time;
}

// Returns the heap size that interval_tree_ reports for the time range
// [2 * loop_size_, 4 * loop_size_ - 1], i.e. two consecutive loop lengths
// starting at time 2 * loop_size_. Per the original author's note, two
// iterations suffice for getting the current alternate memory size.
int64_t LoopOptimizerBestFitHeap::LastMemoryOffsetOccupied() const {
  const int64_t window_start = loop_size_ * 2;
  const int64_t window_end = loop_size_ * 4 - 1;
  return interval_tree_.HeapSizeInInterval(window_start, window_end);
}

/*static*/ absl::StatusOr<std::unique_ptr<MemoryBoundLoopOptimizer>>
MemoryBoundLoopOptimizer::Create(
int loop_start, int loop_end, uint64_t alternate_memory_size,
Expand Down
Loading

0 comments on commit 30bf2dd

Please sign in to comment.