From 6673589701d5bfcb0d03c7f43836988948b6985c Mon Sep 17 00:00:00 2001 From: xla authors Date: Tue, 17 Sep 2024 16:44:03 -0700 Subject: [PATCH] Memory space related copies should not be normalized. PiperOrigin-RevId: 675755862 --- xla/service/host_offload_utils.cc | 12 ++++++++++++ xla/service/host_offload_utils.h | 3 +++ 2 files changed, 15 insertions(+) diff --git a/xla/service/host_offload_utils.cc b/xla/service/host_offload_utils.cc index c306a4d100435..98732043da3a6 100644 --- a/xla/service/host_offload_utils.cc +++ b/xla/service/host_offload_utils.cc @@ -244,5 +244,17 @@ bool IsHostAsyncStart(const HloInstruction* instruction) { instruction->async_execution_thread() == HloInstruction::kHostThread; } +bool IsSynchronousCopyFromOrToHost(const HloInstruction* instruction) { + if (instruction->opcode() != HloOpcode::kCopy) { + return false; + } + return (instruction->shape().has_layout() && + instruction->shape().layout().memory_space() == + Layout::kHostMemorySpace) || + (instruction->operand(0)->shape().has_layout() && + instruction->operand(0)->shape().layout().memory_space() == + Layout::kHostMemorySpace); +} + } // namespace host_offload_utils } // namespace xla diff --git a/xla/service/host_offload_utils.h b/xla/service/host_offload_utils.h index 6546beb3a5285..c71dcc21e4ea8 100644 --- a/xla/service/host_offload_utils.h +++ b/xla/service/host_offload_utils.h @@ -98,6 +98,9 @@ bool IsValidDuringPureMemoryOffload(const HloInstruction* instruction); // Returns true if the instruction is an async-start with host thread. bool IsHostAsyncStart(const HloInstruction* instruction); +// Returns true if the copy is from or to host memory space. +bool IsSynchronousCopyFromOrToHost(const HloInstruction* instruction); + } // namespace host_offload_utils } // namespace xla