Skip to content

Commit

Permalink
Fix IPU PjRt client output stream connection error. (#44)
Browse files Browse the repository at this point in the history
When using an IPU XLA custom op with host callbacks, the new IPU PjRt
client was raising Poplar engine stream errors. Two issues are fixed
here:

* Ignore missing Poplar output streams, just issuing a warning;
* Properly connect Poplar host callbacks;
  • Loading branch information
balancap authored Sep 11, 2023
1 parent d79af3e commit 8d4944b
Showing 1 changed file with 64 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,53 @@ StatusOr<IpuPjRtBufferLocation> CheckInputDonatedBuffers(
"donated inputs.");
}

/**
 * @brief Get Poplar engine input stream names.
 *
 * Walks every name returned by `engine->listStreams()` and keeps the ones
 * whose last character is '+'. The trailing '+' marker is stripped before the
 * name is stored.
 *
 * @param engine Poplar engine to query; dereferenced unconditionally, so it
 *               must not be null.
 * @return Set of stream names (without the '+' suffix).
 */
std::unordered_set<std::string> GetEngineInputStreams(poplar::Engine* engine) {
  std::unordered_set<std::string> input_streams;
  const auto all_streams = engine->listStreams();
  for (const auto& v : all_streams) {
    // Calling `back()` on an empty string is undefined behavior: skip empty
    // names instead of inspecting them.
    if (!v.empty() && v.back() == '+') {
      input_streams.insert(v.substr(0, v.size() - 1));
    }
  }
  return input_streams;
}
/**
 * @brief Get Poplar engine output stream names.
 *
 * Walks every name returned by `engine->listStreams()` and keeps the ones
 * whose last character is '-'. The trailing '-' marker is stripped before the
 * name is stored.
 *
 * @param engine Poplar engine to query; dereferenced unconditionally, so it
 *               must not be null.
 * @return Set of stream names (without the '-' suffix).
 */
std::unordered_set<std::string> GetEngineOutputStreams(poplar::Engine* engine) {
  std::unordered_set<std::string> output_streams;
  const auto all_streams = engine->listStreams();
  for (const auto& v : all_streams) {
    // Calling `back()` on an empty string is undefined behavior: skip empty
    // names instead of inspecting them.
    if (!v.empty() && v.back() == '-') {
      output_streams.insert(v.substr(0, v.size() - 1));
    }
  }
  return output_streams;
}

/**
 * @brief Register a single host callback function on a Poplar engine.
 *
 * Builds a closure that copies the pointer arrays handed over by Poplar into
 * two scratch vectors and then forwards those vectors to
 * `host_function_info.function`. The closure is connected to the engine under
 * `host_function_info.handle`.
 *
 * @param host_function_info Description of the host function (handle, input
 *                           and output shapes, callable). Captured by copy.
 * @param engine Poplar engine on which the callback is registered.
 */
void ConnectStreamHostCallback(const HostFunctionInfo& host_function_info,
                               poplar::Engine* engine) {
  // Scratch storage is sized from the declared shapes; one slot per
  // input/output pointer.
  const auto num_inputs = host_function_info.input_shapes.size();
  const auto num_outputs = host_function_info.output_shapes.size();
  std::vector<const void*> input_ptrs(num_inputs);
  std::vector<void*> output_ptrs(num_outputs);

  // `mutable` is required: the by-copy captured vectors are overwritten on
  // every invocation.
  auto callback = [input_ptrs, output_ptrs, host_function_info](
                      poplar::ArrayRef<const void*> ins,
                      poplar::ArrayRef<void*> outs) mutable {
    absl::c_copy(ins, input_ptrs.begin());
    absl::c_copy(outs, output_ptrs.begin());
    host_function_info.function(input_ptrs, output_ptrs);
  };

  // Only replica 0 is supported for now.
  engine->connectHostFunction(host_function_info.handle, 0, callback);
}
/**
 * @brief Connect every host callback of a collection to a Poplar engine.
 *
 * Iterates over `host_function_infos` and calls `ConnectStreamHostCallback`
 * on the mapped value (`.second`) of each entry.
 *
 * @param host_function_infos Collection of host function descriptions.
 * @param engine Poplar engine on which the callbacks are registered.
 */
void ConnectStreamHostCallbacks(const HostFunctionInfos& host_function_infos,
                                poplar::Engine* engine) {
  for (const auto& entry : host_function_infos) {
    const auto& info = entry.second;
    ConnectStreamHostCallback(info, engine);
  }
}

} // namespace

/**
Expand Down Expand Up @@ -224,6 +271,7 @@ Status CheckPoplarExecutableValid(PoplarExecutable* poplar_executable,
CHECK_EQ(poplar_executable->GetStreamInfos().size(), 0);
CHECK_EQ(poplar_executable->GetSendInfos().size(), 0);
CHECK_EQ(poplar_executable->GetRecvInfos().size(), 0);
CHECK_EQ(poplar_executable->GetRemoteParameterInfos().size(), 0);

// Consistency of compile options.
CHECK(compile_options.executable_build_options.has_device_assignment());
Expand Down Expand Up @@ -427,16 +475,24 @@ void IpuPjRtRunReplicaOutputs::ConnectStreamCallbacks(
const std::vector<InputOutputAliasingMap::OutputInfo>& output_infos,
int replica, poplar::Engine* engine) {
const auto num_outputs = output_infos.size();
const auto engine_output_streams = GetEngineOutputStreams(engine);
for (std::size_t i = 0; i < num_outputs; ++i) {
const auto& outinfo = output_infos[i];
// Connect only streamed outputs.
if (outinfo.IsStreaming()) {
const auto& outname = outinfo.Handles()[0];
// TODO: support tuples properly.
engine->connectStreamToCallback(
outname, replica,
std::make_unique<IpuOutputStreamCallback>(
host_tracked_buffers[i]->Buffers()[0]));
if (engine_output_streams.count(outname) == 0) {
LOG_FIRST_N(WARNING, 1) << absl::StrFormat(
"Ignoring XLA streaming output '%s'. No associated Poplar output "
"stream.",
outname);
} else {
// TODO: support tuples properly.
engine->connectStreamToCallback(
outname, replica,
std::make_unique<IpuOutputStreamCallback>(
host_tracked_buffers[i]->Buffers()[0]));
}
}
}
}
Expand Down Expand Up @@ -1024,6 +1080,9 @@ void IpuPjRtExecutable::ExecuteDeviceRun(IpuPjRtRunState& run_state) {
run_state.ConnectStreamCallbacks(io_aliasing_map.GetEntryInputInfos(),
io_aliasing_map.GetEntryOutputInfos(),
engine);
// Connect engine streams corresponding to host callbacks.
ConnectStreamHostCallbacks(poplar_executable->GetHostFunctionInfos(), engine);

// Synchronous call => blocking thread!
LOG(INFO) << "Run IPU poplar engine " << name()
<< "; executable id: " << run_state.run_info.executable_id
Expand Down

0 comments on commit 8d4944b

Please sign in to comment.