Improve IPU PjRt client asynchronous dispatch performance. (#48)
In the IPU PjRt client, `Poplar::Engine::run` is called in a
dedicated compute thread to minimize overhead. This PR improves the
performance of the asynchronous dispatch loop by moving the cleanup
phase (i.e. deleting the large run state data structures) to a dedicated thread.

This removes 40-50us of overhead in the compute thread.
balancap authored Sep 19, 2023
1 parent 2e00430 commit 7033152
Showing 2 changed files with 68 additions and 37 deletions.
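
The change amounts to a producer/consumer split: after the Poplar engine run, the compute thread hands the finished run state to a second queue instead of destroying it inline, and a dedicated cleanup thread does nothing but drop those states. Below is a minimal standalone sketch of that pattern, using simplified `BlockingQueue` and `RunState` stand-ins rather than the actual `ThreadSafeQueue`/`IpuPjRtRunState` classes from the diff:

```cpp
// Minimal standalone sketch of the dispatch/cleanup split. BlockingQueue,
// RunState and the main() driver are simplified stand-ins for illustration,
// not the actual IPU PjRt client classes.
#include <atomic>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

// Simple blocking queue, standing in for the client's ThreadSafeQueue.
template <typename T>
class BlockingQueue {
 public:
  void enqueue(T item) {
    {
      std::lock_guard<std::mutex> lock(m_mutex);
      m_items.push(std::move(item));
    }
    m_cv.notify_one();
  }
  T dequeue() {
    std::unique_lock<std::mutex> lock(m_mutex);
    m_cv.wait(lock, [this] { return !m_items.empty(); });
    T item = std::move(m_items.front());
    m_items.pop();
    return item;
  }

 private:
  std::mutex m_mutex;
  std::condition_variable m_cv;
  std::queue<T> m_items;
};

// Placeholder for the large per-run state whose destruction is deferred.
struct RunState {
  std::vector<int> large_buffers = std::vector<int>(1 << 20);
};

int main() {
  BlockingQueue<std::unique_ptr<RunState>> execute_queue;
  BlockingQueue<std::unique_ptr<RunState>> cleanup_queue;
  std::atomic_bool deleted{false};

  // Compute thread: run the engine, then hand the state off instead of
  // destroying it inline (that inline destruction is the overhead removed).
  std::thread execute_thread([&] {
    while (!deleted.load()) {
      auto state = execute_queue.dequeue();
      if (!state) continue;  // nullptr sentinel => shutting down.
      // ... engine.run(...) would happen here ...
      cleanup_queue.enqueue(std::move(state));
    }
  });

  // Cleanup thread: its only job is to destroy run states.
  std::thread cleanup_thread([&] {
    while (!deleted.load()) {
      auto state = cleanup_queue.dequeue();
      if (!state) continue;  // nullptr sentinel => shutting down.
      state.reset();  // Large state is freed off the compute thread.
    }
  });

  // Dispatch a few runs, then shut down with nullptr sentinels.
  for (int i = 0; i < 4; ++i) {
    execute_queue.enqueue(std::make_unique<RunState>());
  }
  deleted.store(true);
  execute_queue.enqueue(nullptr);
  cleanup_queue.enqueue(nullptr);
  execute_thread.join();
  cleanup_thread.join();
  return 0;
}
```

The nullptr sentinel used for shutdown mirrors the updated `Delete()` in the diff below, which enqueues a null run state into both queues to unblock and join the two threads.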
@@ -593,18 +593,18 @@ IpuPjRtRunState& IpuPjRtRunState::operator=(IpuPjRtRunState&& rhs) noexcept {
return *this;
}

StatusOr<IpuPjRtRunState> IpuPjRtRunState::CreateWithIOBuffers(
StatusOr<std::unique_ptr<IpuPjRtRunState>> IpuPjRtRunState::CreateWithIOBuffers(
tfrt::AsyncValueRef<CpuEvent> execute_event,
absl::Span<const std::vector<PjRtBuffer*>> all_input_handles,
const IpuPjRtInputOutputAliasing& input_output_aliasing,
const std::vector<InputOutputAliasingMap::OutputInfo>& output_infos,
xla::TransposePlanCache& transpose_cache) {
const auto num_replicas = all_input_handles.size();
IpuPjRtRunState run_state;
auto run_state = std::make_unique<IpuPjRtRunState>();
// Copy the execute event into the state.
run_state.execute_event = std::move(execute_event);
run_state.all_inputs.reserve(num_replicas);
run_state.all_outputs.reserve(num_replicas);
run_state->execute_event = std::move(execute_event);
run_state->all_inputs.reserve(num_replicas);
run_state->all_outputs.reserve(num_replicas);
for (std::size_t replica = 0; replica < num_replicas; ++replica) {
// Input buffers for the replica.
TF_ASSIGN_OR_RETURN(auto inputs,
@@ -614,10 +614,10 @@ StatusOr<IpuPjRtRunState> IpuPjRtRunState::CreateWithIOBuffers(
// Raw output buffers for the replica.
TF_ASSIGN_OR_RETURN(auto outputs,
IpuPjRtRunReplicaOutputs::AllocateFromOutputInfos(
run_state.execute_event, inputs,
run_state->execute_event, inputs,
input_output_aliasing, output_infos));
run_state.all_inputs.push_back(std::move(inputs));
run_state.all_outputs.push_back(std::move(outputs));
run_state->all_inputs.push_back(std::move(inputs));
run_state->all_outputs.push_back(std::move(outputs));
}
return run_state;
}
@@ -905,9 +905,11 @@ IpuPjRtExecutable::IpuPjRtExecutable(
LOG(INFO) << "IPU PjRt executable input/output aliasing map: "
<< m_input_output_aliasing.ToString();

// Start execute thread in asynchronous case.
// Start execute + cleanup threads in asynchronous case.
if (m_asynchronous_run) {
m_execute_thread = std::thread(&IpuPjRtExecutable::ExecuteAsyncLoop, this);
m_cleanup_thread =
std::thread(&IpuPjRtExecutable::ExecuteCleanupAsyncLoop, this);
}
}
IpuPjRtExecutable::~IpuPjRtExecutable() {
@@ -1025,13 +1027,13 @@ IpuPjRtExecutable::Execute(
IpuPjRtRunState::CreateWithIOBuffers(
execute_event, argument_handles, m_input_output_aliasing,
io_aliasing_map.GetEntryOutputInfos(), m_host_transpose_cache));
run_state.inputs_donated_location = inputs_donated_location;
run_state->inputs_donated_location = inputs_donated_location;

// Wrapping execute event as PjRt future status.
auto done_event = tfrt::MakeUnconstructedAsyncValueRef<Status>();
run_state.execute_event.AndThen(
run_state->execute_event.AndThen(
[done_event = done_event.CopyRef(),
event = run_state.execute_event.CopyRef()]() {
event = run_state->execute_event.CopyRef()]() {
Status s;
if (auto* error = event.GetErrorIfPresent()) {
s = InternalError("Compute error: %s", error->message);
@@ -1041,13 +1043,13 @@ IpuPjRtExecutable::Execute(
auto future_status = PjRtFuture<Status>(std::move(done_event));

// Returned outputs, for all replica, with IPU executable reference.
auto outputs = run_state.CreateOutputIpuPjRtBuffers(
run_state.execute_event, this->input_output_aliasing(),
auto outputs = run_state->CreateOutputIpuPjRtBuffers(
run_state->execute_event, this->input_output_aliasing(),
io_aliasing_map.GetEntryOutputInfos(), m_devices, this);
// Returned futures, for all replica.
if (returned_futures) {
returned_futures.value().clear();
for (std::size_t idx = 0; idx < run_state.num_replicas(); ++idx) {
for (std::size_t idx = 0; idx < run_state->num_replicas(); ++idx) {
returned_futures.value().push_back(future_status);
}
}
@@ -1068,15 +1070,15 @@ IpuPjRtExecutable::Execute(
// NOTE: passing inputs donated location, as used for marking previous SRAM
// (UNCHANGED or UPDATED) buffers as expired (or not).
auto [update_run_info, update_mesh_transition] = m_client->UpdateClientState(
m_device_mesh_id, m_executable_id, run_state.inputs_donated_location,
m_last_run_outputs_ref, run_state.execute_event.CopyRef());
m_device_mesh_id, m_executable_id, run_state->inputs_donated_location,
m_last_run_outputs_ref, run_state->execute_event.CopyRef());
// Move run info and mesh transition info.
run_state.run_info = std::move(update_run_info);
run_state.mesh_transition = std::move(update_mesh_transition);
run_state->run_info = std::move(update_run_info);
run_state->mesh_transition = std::move(update_mesh_transition);

// No need for inputs scoped hold => convert to usage event.
// NOTE: necessary, otherwise the scoped hold blocks buffer deletion.
run_state.ConvertInputBufferHold();
run_state->ConvertInputBufferHold();
// Mark unchanged donated input buffers as SRAM synchronized.
// Allows writing nice JAX inference code, where the donated output is
// ignored and weights are on SRAM for next iteration.
@@ -1087,19 +1089,19 @@ IpuPjRtExecutable::Execute(
<< " num_partitions=" << num_partitions()
<< "; num_addressable_devices=" << num_addressable_devices
<< "; num_inputs=" << num_inputs << " num_outputs=" << num_outputs
<< "; executable id: " << run_state.run_info.executable_id
<< "; run id: " << run_state.run_info.run_id
<< "; executable id: " << run_state->run_info.executable_id
<< "; run id: " << run_state->run_info.run_id
<< "; mesh_id: " << m_device_mesh_id;

/////////////// RUNNING POPLAR ENGINE ///////////////
/////////////// RUNNING POPLAR ENGINE ///////////////
/////////////// RUNNING POPLAR ENGINE ///////////////
if (m_asynchronous_run) {
// Queue the state, for async. Poplar engine run.
m_execute_queue.enqueue(std::move(run_state));
m_execute_run_state_queue.enqueue(std::move(run_state));
} else {
// Synchronous Poplar engine run.
this->ExecuteDeviceRun(run_state);
this->ExecuteDeviceRun(*run_state);
}
return outputs;
}
@@ -1169,17 +1171,34 @@ void IpuPjRtExecutable::ExecuteDeviceRun(IpuPjRtRunState& run_state) {
}

void IpuPjRtExecutable::ExecuteAsyncLoop() {
while (!m_executable_is_deleted.load()) {
while (!this->IsDeleted()) {
// Wait for the next run state.
auto run_state = m_execute_queue.dequeue();
// Ignore empty/dummy run state.
if (run_state.empty()) {
auto run_state = m_execute_run_state_queue.dequeue();
// Ignore empty/nullptr run state => break loop.
if (!run_state) {
continue;
}
// Blocking Poplar engine run.
this->ExecuteDeviceRun(run_state);
// Synchronous/blocking Poplar engine run.
this->ExecuteDeviceRun(*run_state);
// Mark the execute event as done!
run_state.execute_event.SetStateConcrete();
run_state->execute_event.SetStateConcrete();
// Move the run state to the cleanup queue.
// Why?! Deleting the run state is non-negligible overhead in the loop!
m_clean_run_state_queue.enqueue(std::move(run_state));
}
}

void IpuPjRtExecutable::ExecuteCleanupAsyncLoop() {
while (!this->IsDeleted()) {
// Wait for the next state to cleanup!
auto run_state = m_clean_run_state_queue.dequeue();
TENSORFLOW_TRACEPOINT();
// Ignore empty/nullptr run state => break loop.
if (!run_state) {
continue;
}
// Explicit reset (not really necessary!)
run_state.reset();
}
}

@@ -1223,9 +1242,11 @@ IpuPjRtExecutable::ExecutePortable(
void IpuPjRtExecutable::Delete() {
m_executable_is_deleted.store(true);
if (m_asynchronous_run) {
// Queue empty run, just to unblock thread.
m_execute_queue.enqueue(IpuPjRtRunState());
// Enqueue nullptr, to unblock threads.
m_execute_run_state_queue.enqueue(nullptr);
m_clean_run_state_queue.enqueue(nullptr);
m_execute_thread.join();
m_cleanup_thread.join();
}
// Mark last run on-device expired, and nullify ptr.
// Avoid error if trying to fetch on-device values.
@@ -256,7 +256,7 @@ struct IpuPjRtRunState {
* @param transpose_cache Host transpose cache.
* @return IPU run state with proper IO buffers.
*/
static StatusOr<IpuPjRtRunState> CreateWithIOBuffers(
static StatusOr<std::unique_ptr<IpuPjRtRunState>> CreateWithIOBuffers(
tfrt::AsyncValueRef<CpuEvent> execute_event,
absl::Span<const std::vector<PjRtBuffer*>> all_input_handles,
const IpuPjRtInputOutputAliasing& input_output_aliasing,
@@ -485,10 +485,14 @@ class IpuPjRtExecutable : public PjRtExecutable {
const ExecuteOptions& options,
std::optional<std::vector<PjRtFuture<Status>>>& returned_futures);

/** Execute loop fucntion, used in the asynchronous case.
/** Execute loop function, used in the asynchronous case.
* This method is run in a separate execute thread.
*/
void ExecuteAsyncLoop();
/** Cleanup asynchronous loop => avoid blocking main execution
* thread just to destroy objects!
*/
void ExecuteCleanupAsyncLoop();

/** Asynchronous execution on IPU? */
bool m_asynchronous_run = false;
@@ -518,8 +522,14 @@ class IpuPjRtExecutable : public PjRtExecutable {

/** Asynchronous execute thread. */
std::thread m_execute_thread;
/** Asynchronous execute queue. */
ThreadSafeQueue<IpuPjRtRunState> m_execute_queue;
/** Asynchronous cleanup thread. */
std::thread m_cleanup_thread;

/** Asynchronous execute run state queue. */
ThreadSafeQueue<std::unique_ptr<IpuPjRtRunState>> m_execute_run_state_queue;
/** Asynchronous execute run state cleanup queue. */
ThreadSafeQueue<std::unique_ptr<IpuPjRtRunState>> m_clean_run_state_queue;

/** Executable delete status. */
std::atomic_bool m_executable_is_deleted{false};
