Skip to content

Commit

Permalink
Allow using host layout for argument when allocating buffers.
Browse files Browse the repository at this point in the history
Extends the multihost HLO runner's RunningOptions so that, when copying arguments to the device, the layout of the host literal can optionally be used for the argument buffers.

PiperOrigin-RevId: 681672653
  • Loading branch information
Google-ML-Automation committed Oct 3, 2024
1 parent e151a58 commit 0adc6d0
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 21 deletions.
42 changes: 22 additions & 20 deletions xla/tools/multihost_hlo_runner/functional_hlo_runner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -822,14 +822,13 @@ FunctionalHloRunner::Run(PjRtClient& client, PjRtLoadedExecutable* executable,
flattened_arguments.insert({device_id, std::move(flattened_argument)});
}
return CopyArgumentsToDevice(client, executable, flattened_arguments,
running_options.log_input_output(),
running_options,
/*flattened_arguments=*/true);
}
// If the per-device argument is not a single tuple, we ignore the
// flatten_tupled_arguments parameter and assume the provided arguments have
// already been flattened.
return CopyArgumentsToDevice(client, executable, arguments,
running_options.log_input_output(),
return CopyArgumentsToDevice(client, executable, arguments, running_options,
/*flattened_arguments=*/false);
};
return RunInternal(client, executable, create_argument_buffers_on_device,
Expand Down Expand Up @@ -1164,14 +1163,13 @@ FunctionalHloRunner::CreateArgumentsOnDevice(
}

if (kUseSharedInputs) {
return CopyArgumentsToDevice(
client, executable, per_device_argument_literals,
running_options.log_input_output(), flatten_arguments,
/*clone_device0_arguments=*/true);
return CopyArgumentsToDevice(client, executable,
per_device_argument_literals, running_options,
flatten_arguments,
/*clone_device0_arguments=*/true);
}
return CopyArgumentsToDevice(client, executable, per_device_argument_literals,
running_options.log_input_output(),
flatten_arguments);
running_options, flatten_arguments);
}

absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Expand Down Expand Up @@ -1261,8 +1259,10 @@ FunctionalHloRunner::CreateUninitializedArgumentsOnDevice(
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
FunctionalHloRunner::CopyArgumentsToDevice(
PjRtClient& client, const PjRtLoadedExecutable* executable,
const PerDeviceLiteralVecType& arguments, bool log_input,
bool flattened_arguments, bool clone_device0_arguments) {
const PerDeviceLiteralVecType& arguments,
const RunningOptions& running_options, bool flattened_arguments,
bool clone_device0_arguments) {
const bool log_input = running_options.log_input_output();
absl::Span<PjRtDevice* const> addressable_devices =
executable->addressable_devices();
size_t num_addressable_devices = addressable_devices.size();
Expand Down Expand Up @@ -1301,20 +1301,22 @@ FunctionalHloRunner::CopyArgumentsToDevice(
TF_RET_CHECK(!shape.IsTuple()) << "Param tuple without flattened_arguments";
return non_tuple_memory_space(shape);
};
auto buffer_from_host_literal = [&client, &argument_memory_space](
const HloModule* module,
PjRtDevice* device, int arg_i,
const Literal& literal)
auto buffer_from_host_literal =
[&client, &argument_memory_space, &running_options](
const HloModule* module, PjRtDevice* device, int arg_i,
const Literal& literal)
-> absl::StatusOr<std::unique_ptr<PjRtBuffer>> {
const Layout* layout = nullptr;
if (running_options.use_argument_host_layout &&
literal.shape().has_layout()) {
layout = &literal.shape().layout();
}
if (client.memory_spaces().empty()) {
return client.BufferFromHostLiteral(
literal, device,
literal.shape().has_layout() ? &literal.shape().layout() : nullptr);
return client.BufferFromHostLiteral(literal, device, layout);
}
TF_ASSIGN_OR_RETURN(PjRtMemorySpace * memory_space,
argument_memory_space(module, device, arg_i));
return client.BufferFromHostLiteral(literal, memory_space,
/* device_layout */ nullptr);
return client.BufferFromHostLiteral(literal, memory_space, layout);
};

absl::Span<const PjRtLoadedExecutable::LogicalDeviceIds>
Expand Down
5 changes: 4 additions & 1 deletion xla/tools/multihost_hlo_runner/functional_hlo_runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,9 @@ class FunctionalHloRunner {
// Whether to untuple the result of running HLO module into a vector of
// arrays. If unprovided, use the default in ExecuteOptions.
std::optional<bool> untuple_result = std::nullopt;
// Whether to pass the host literal's layout (when the literal has one) as
// the device layout when creating argument buffers. When false, no layout
// is passed. Some platforms (e.g. CPU) do not support this yet.
bool use_argument_host_layout = false;

// Should we log the inputs and outputs to stderr?
bool log_input_output() const {
Expand Down Expand Up @@ -377,7 +380,7 @@ class FunctionalHloRunner {
CopyArgumentsToDevice(PjRtClient& client,
const PjRtLoadedExecutable* executable,
const PerDeviceLiteralVecType& arguments,
bool log_input, bool flattened_arguments,
const RunningOptions& options, bool flattened_arguments,
bool clone_device0_arguments = false);

static absl::StatusOr<PerDeviceLiteralVecType> RunInternal(
Expand Down

0 comments on commit 0adc6d0

Please sign in to comment.