[XLA:GPU] Share common HLO computations between some of the pipeline tests

PiperOrigin-RevId: 658441952
frgossen authored and copybara-github committed Aug 1, 2024
1 parent 336c9a9 commit 0609842
Showing 1 changed file with 109 additions and 162 deletions.

xla/tests/collective_pipeline_parallelism_test.cc
@@ -15,6 +15,7 @@ limitations under the License.

#include <cstdint>
#include <memory>
+#include <string>
#include <utility>
#include <vector>

@@ -546,6 +547,59 @@ XLA_TEST_F(CollectivePipelineParallelismTest,
ErrorSpec{1e-5, 1e-5}));
}

+std::string GetModuleStrWithCommonComputations(
+const std::string name, const std::string more_computations) {
+static constexpr char kCommonComputationsStr[] = R"(
+read_buffer_mb5 {
+buffer = f32[5,16] parameter(0)
+offset = u32[] parameter(1)
+index = u32[] parameter(2)
+c0 = u32[] constant(0)
+c5 = u32[] constant(5)
+index_ = u32[] add(index, offset)
+index__ = u32[] remainder(index_, c5)
+slice = f32[1,16] dynamic-slice(buffer, index__, c0),
+dynamic_slice_sizes={1,16}
+ROOT slice_ = f32[16] reshape(slice)
+}
+update_buffer_mb5 {
+buffer = f32[5,16] parameter(0)
+update = f32[16] parameter(1)
+offset = u32[] parameter(2)
+index = u32[] parameter(3)
+c0 = u32[] constant(0)
+c5 = u32[] constant(5)
+index_ = u32[] add(index, offset)
+index__ = u32[] remainder(index_, c5)
+update_ = f32[1,16] reshape(update)
+ROOT buffer_ = f32[5,16] dynamic-update-slice(buffer, update_, index__, c0)
+}
+is_input_replica {
+replica_id = u32[] replica-id()
+c0 = u32[] constant(0)
+ROOT predicate = pred[] compare(replica_id, c0), direction=EQ
+}
+is_output_replica {
+replica_id = u32[] replica-id()
+c3 = u32[] constant(3)
+ROOT predicate = pred[] compare(replica_id, c3), direction=EQ
+}
+is_read_input_mb5 {
+is_input_replica = pred[] call(), to_apply=is_input_replica
+i = u32[] parameter(0)
+c5 = u32[] constant(5)
+is_input_iteration = pred[] compare(i, c5), direction=LT
+ROOT is_read_input = pred[] and(is_input_replica, is_input_iteration)
+}
+)";
+return "HloModule " + name + "\n" + kCommonComputationsStr + "\n" +
+more_computations;
+}

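The shared helpers boil down to simple index and gating arithmetic. Below is a minimal, self-contained sketch (editor's illustration, not part of the commit; the function names are hypothetical) of the semantics of read_buffer_mb5/update_buffer_mb5 and is_read_input_mb5:

#include <cstdint>
#include <iostream>

// Mirrors read_buffer_mb5 / update_buffer_mb5: a 5-slot circular buffer is
// addressed at slot (index + offset) mod 5 (the add + remainder in the HLO).
uint32_t CircularSlot(uint32_t index, uint32_t offset) {
  return (index + offset) % 5;
}

// Mirrors is_read_input_mb5: only the input replica (replica 0) reads a fresh
// microbatch, and only during the first 5 iterations.
bool IsReadInput(uint32_t replica_id, uint32_t i) {
  return replica_id == 0 && i < 5;
}

int main() {
  // With read offset 3 (used for `buffer` in the tests below), iterations map
  // onto slots 3, 4, 0, 1, 2, 3, ...
  for (uint32_t i = 0; i < 6; ++i) {
    std::cout << "i=" << i << " -> slot " << CircularSlot(i, 3)
              << ", replica 0 reads input: " << IsReadInput(0, i) << "\n";
  }
  return 0;
}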
// Naive implementation of pipeline parallelism:
// - 4 devices
// - 5 microbatches
@@ -556,65 +610,7 @@ XLA_TEST_F(CollectivePipelineParallelismTest,
// Every stage of the pipeline is a single linear layer.
XLA_TEST_F(CollectivePipelineParallelismTest,
NaiveDFSMicrobatch5CircularRepeat2Replica4) {
-const absl::string_view kModuleStr = R"(
-HloModule test
-get_circ_buffer_index {
-offset = u32[] parameter(0)
-index = u32[] parameter(1)
-size = u32[] parameter(2)
-t0 = u32[] add(offset, index)
-t1 = u32[] divide(t0, size)
-t2 = u32[] multiply(t1, size)
-ROOT t4 = u32[] subtract(t0, t2)
-}
-read_buffer {
-buffer = f32[5,16] parameter(0)
-offset = u32[] parameter(1)
-index = u32[] parameter(2)
-c0 = u32[] constant(0)
-c5 = u32[] constant(5)
-index_ = u32[] add(index, offset)
-index__ = u32[] remainder(index_, c5)
-slice = f32[1,16] dynamic-slice(buffer, index__, c0),
-dynamic_slice_sizes={1,16}
-ROOT slice_ = f32[16] reshape(slice)
-}
-update_buffer {
-buffer = f32[5,16] parameter(0)
-update = f32[16] parameter(1)
-offset = u32[] parameter(2)
-index = u32[] parameter(3)
-c0 = u32[] constant(0)
-c5 = u32[] constant(5)
-index_ = u32[] add(index, offset)
-index__ = u32[] remainder(index_, c5)
-update_ = f32[1,16] reshape(update)
-ROOT buffer_ = f32[5,16] dynamic-update-slice(buffer, update_, index__, c0)
-}
-is_input_replica {
-replica_id = u32[] replica-id()
-c0 = u32[] constant(0)
-ROOT predicate = pred[] compare(replica_id, c0), direction=EQ
-}
-is_output_replica {
-replica_id = u32[] replica-id()
-c3 = u32[] constant(3)
-ROOT predicate = pred[] compare(replica_id, c3), direction=EQ
-}
-is_read_input {
-is_input_replica = pred[] call(), to_apply=is_input_replica
-i = u32[] parameter(0)
-c5 = u32[] constant(5)
-is_input_iteration = pred[] compare(i, c5), direction=LT
-ROOT is_read_input = pred[] and(is_input_replica, is_input_iteration)
-}
+constexpr char kMoreComputationsStr[] = R"(
while_condition {
tuple = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[])
parameter(0)
@@ -630,43 +626,46 @@ XLA_TEST_F(CollectivePipelineParallelismTest,
input = f32[5,16] get-tuple-element(tuple), index=1
output = f32[5,16] get-tuple-element(tuple), index=2
buffer = f32[5,16] get-tuple-element(tuple), index=3
-prev_iteration_compute_out = f32[16] get-tuple-element(tuple), index=4
+prev_iteration_compute_res = f32[16] get-tuple-element(tuple), index=4
i = u32[] get-tuple-element(tuple), index=5
c0 = u32[] constant(0)
c1 = u32[] constant(1)
c2 = u32[] constant(2)
c3 = u32[] constant(3)
c4 = u32[] constant(4)
c5 = u32[] constant(5)
-input_idx = u32[] call(c0, i, c5), to_apply=get_circ_buffer_index
-input_slice = f32[1,16] dynamic-slice(input, input_idx, c0),
-dynamic_slice_sizes={1,16}
-input_slice_ = f32[16] reshape(input_slice)
-buffer_slice = f32[16] call(buffer, c3, i), to_apply=read_buffer
+// Read from buffers.
+input_slice = f32[16] call(input, c0, i), to_apply=read_buffer_mb5
+buffer_slice = f32[16] call(buffer, c3, i), to_apply=read_buffer_mb5
+// Shift data to the next stage in the pipeline.
+// Directly depends on the updated buffer of the previous iteration and,
+// therefore, depends on the previous iteration's compute.
is_output_replica = pred[] call(), to_apply=is_output_replica
next_stage_slice = select(is_output_replica, buffer_slice,
-prev_iteration_compute_out)
+prev_iteration_compute_res)
prev_stage_slice = f32[16] collective-permute(next_stage_slice),
source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}}
-is_read_input = pred[] call(i), to_apply=is_read_input
-compute_in = f32[16] select(is_read_input, input_slice_, prev_stage_slice)
-compute_out = f32[16] dot(weights, compute_in), lhs_contracting_dims={1},
+// Select compute argument from previous stage or from input and perform
+// compute.
+is_read_input = pred[] call(i), to_apply=is_read_input_mb5
+compute_arg = f32[16] select(is_read_input, input_slice, prev_stage_slice)
+compute_res = f32[16] dot(weights, compute_arg), lhs_contracting_dims={1},
rhs_contracting_dims={0}
-output_ = f32[5,16] call(output, compute_out, c2, i), to_apply=update_buffer
-buffer_ = f32[5,16] call(buffer, compute_out, c0, i), to_apply=update_buffer
+// Update buffers.
+output_ = f32[5,16] call(output, compute_res, c2, i),
+to_apply=update_buffer_mb5
+buffer_ = f32[5,16] call(buffer, compute_res, c0, i),
+to_apply=update_buffer_mb5
i_ = add(i, c1)
ROOT tuple_ = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[])
-tuple(weights, input, output_, buffer_, compute_out, i_)
+tuple(weights, input, output_, buffer_, compute_res, i_)
}
ENTRY main {
@@ -676,11 +675,12 @@ XLA_TEST_F(CollectivePipelineParallelismTest,
cf0 = f32[] constant(0)
output = f32[5,16] broadcast(cf0), dimensions={}
buffer = f32[5,16] broadcast(cf0), dimensions={}
-prev_iteration_compute_out = f32[16] broadcast(cf0), dimensions={}
+prev_iteration_compute_res = f32[16] broadcast(cf0), dimensions={}
c0 = u32[] constant(0)
+// Iterate through pipeline stages.
tuple = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[])
-tuple(weights, input, output, buffer, prev_iteration_compute_out, c0)
+tuple(weights, input, output, buffer, prev_iteration_compute_res, c0)
tuple_ = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[])
while(tuple), condition=while_condition, body=while_body
@@ -693,8 +693,11 @@ XLA_TEST_F(CollectivePipelineParallelismTest,

HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
-TF_ASSERT_OK_AND_ASSIGN(auto module,
-ParseAndReturnVerifiedModule(kModuleStr, config));
+TF_ASSERT_OK_AND_ASSIGN(
+auto module,
+ParseAndReturnVerifiedModule(GetModuleStrWithCommonComputations(
+/*name=*/"test", kMoreComputationsStr),
+config));

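Every migrated test now follows the parse pattern above: compose the shared computations with a test-specific body, then parse the result. A condensed sketch (editor's illustration; kExampleComputationsStr and its ENTRY body are hypothetical):

constexpr char kExampleComputationsStr[] = R"(
ENTRY main {
  i = u32[] parameter(0)
  ROOT gate = pred[] call(i), to_apply=is_read_input_mb5
}
)";

HloModuleConfig config =
    GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
TF_ASSERT_OK_AND_ASSIGN(
    auto module,
    ParseAndReturnVerifiedModule(
        GetModuleStrWithCommonComputations(/*name=*/"example",
                                           kExampleComputationsStr),
        config));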
// This pipeline consists of a total of 8 layers (2 per replica), each of
// which is a single linear layer. We assign the weights to the replicas such
@@ -745,65 +748,7 @@ XLA_TEST_F(CollectivePipelineParallelismTest,
// Every stage of the pipeline is a single linear layer.
XLA_TEST_F(CollectivePipelineParallelismTest,
NaiveWoDirectBufferDependencyDFSMicrobatch5CircularRepeat2Replica4) {
-const absl::string_view kModuleStr = R"(
-HloModule test
-get_circ_buffer_index {
-offset = u32[] parameter(0)
-index = u32[] parameter(1)
-size = u32[] parameter(2)
-t0 = u32[] add(offset, index)
-t1 = u32[] divide(t0, size)
-t2 = u32[] multiply(t1, size)
-ROOT t4 = u32[] subtract(t0, t2)
-}
-read_buffer {
-buffer = f32[5,16] parameter(0)
-offset = u32[] parameter(1)
-index = u32[] parameter(2)
-c0 = u32[] constant(0)
-c5 = u32[] constant(5)
-index_ = u32[] add(index, offset)
-index__ = u32[] remainder(index_, c5)
-slice = f32[1,16] dynamic-slice(buffer, index__, c0),
-dynamic_slice_sizes={1,16}
-ROOT slice_ = f32[16] reshape(slice)
-}
-update_buffer {
-buffer = f32[5,16] parameter(0)
-update = f32[16] parameter(1)
-offset = u32[] parameter(2)
-index = u32[] parameter(3)
-c0 = u32[] constant(0)
-c5 = u32[] constant(5)
-index_ = u32[] add(index, offset)
-index__ = u32[] remainder(index_, c5)
-update_ = f32[1,16] reshape(update)
-ROOT buffer_ = f32[5,16] dynamic-update-slice(buffer, update_, index__, c0)
-}
-is_input_replica {
-replica_id = u32[] replica-id()
-c0 = u32[] constant(0)
-ROOT predicate = pred[] compare(replica_id, c0), direction=EQ
-}
-is_output_replica {
-replica_id = u32[] replica-id()
-c3 = u32[] constant(3)
-ROOT predicate = pred[] compare(replica_id, c3), direction=EQ
-}
-is_read_input {
-is_input_replica = pred[] call(), to_apply=is_input_replica
-i = u32[] parameter(0)
-c5 = u32[] constant(5)
-is_input_iteration = pred[] compare(i, c5), direction=LT
-ROOT is_read_input = pred[] and(is_input_replica, is_input_iteration)
-}
+constexpr char kMoreComputationsStr[] = R"(
while_condition {
tuple = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[])
parameter(0)
@@ -819,7 +764,7 @@ XLA_TEST_F(CollectivePipelineParallelismTest,
input = f32[5,16] get-tuple-element(tuple), index=1
output = f32[5,16] get-tuple-element(tuple), index=2
buffer = f32[5,16] get-tuple-element(tuple), index=3
-prev_iteration_compute_out = f32[16] get-tuple-element(tuple), index=4
+prev_iteration_compute_res = f32[16] get-tuple-element(tuple), index=4
i = u32[] get-tuple-element(tuple), index=5
c0 = u32[] constant(0)
@@ -829,38 +774,36 @@ XLA_TEST_F(CollectivePipelineParallelismTest,
c4 = u32[] constant(4)
c5 = u32[] constant(5)
-input_idx = u32[] call(c0, i, c5), to_apply=get_circ_buffer_index
-input_slice = f32[1,16] dynamic-slice(input, input_idx, c0),
-dynamic_slice_sizes={1,16}
-input_slice_ = f32[16] reshape(input_slice)
-buffer_slice = f32[16] call(buffer, c3, i), to_apply=read_buffer
-buffer_ = f32[5,16] call(buffer, prev_iteration_compute_out, c4, i),
-to_apply=update_buffer
+// Read from buffers before they are updated.
+input_slice = f32[16] call(input, c0, i), to_apply=read_buffer_mb5
+buffer_slice = f32[16] call(buffer, c3, i), to_apply=read_buffer_mb5
+// Shift data to the next stage in the pipeline.
+// Depends on the non-updated buffer of the previous iteration and,
+// therefore, does not depend on the previous iteration's compute.
is_output_replica = pred[] call(), to_apply=is_output_replica
next_stage_slice = select(is_output_replica, buffer_slice,
-prev_iteration_compute_out)
+prev_iteration_compute_res)
prev_stage_slice = f32[16] collective-permute(next_stage_slice),
source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}}
-is_read_input = pred[] call(i), to_apply=is_read_input
-compute_in = f32[16] select(is_read_input, input_slice_, prev_stage_slice)
-compute_out = f32[16] dot(weights, compute_in), lhs_contracting_dims={1},
+// Select compute argument from previous stage or from input and perform
+// compute.
+is_read_input = pred[] call(i), to_apply=is_read_input_mb5
+compute_arg = f32[16] select(is_read_input, input_slice, prev_stage_slice)
+compute_res = f32[16] dot(weights, compute_arg), lhs_contracting_dims={1},
rhs_contracting_dims={0}
-output_ = f32[5,16] call(output, compute_out, c2, i), to_apply=update_buffer
+// Update buffers.
+buffer_ = f32[5,16] call(buffer, prev_iteration_compute_res, c4, i),
+to_apply=update_buffer_mb5
+output_ = f32[5,16] call(output, compute_res, c2, i),
+to_apply=update_buffer_mb5
i_ = add(i, c1)
ROOT tuple_ = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[])
-tuple(weights, input, output_, buffer_, compute_out, i_)
+tuple(weights, input, output_, buffer_, compute_res, i_)
}
ENTRY main {
@@ -870,11 +813,12 @@ XLA_TEST_F(CollectivePipelineParallelismTest,
cf0 = f32[] constant(0)
output = f32[5,16] broadcast(cf0), dimensions={}
buffer = f32[5,16] broadcast(cf0), dimensions={}
-prev_iteration_compute_out = f32[16] broadcast(cf0), dimensions={}
+prev_iteration_compute_res = f32[16] broadcast(cf0), dimensions={}
c0 = u32[] constant(0)
+// Iterate through pipeline stages.
tuple = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[])
-tuple(weights, input, output, buffer, prev_iteration_compute_out, c0)
+tuple(weights, input, output, buffer, prev_iteration_compute_res, c0)
tuple_ = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[])
while(tuple), condition=while_condition, body=while_body
@@ -887,8 +831,11 @@ XLA_TEST_F(CollectivePipelineParallelismTest,

HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
-TF_ASSERT_OK_AND_ASSIGN(auto module,
-ParseAndReturnVerifiedModule(kModuleStr, config));
+TF_ASSERT_OK_AND_ASSIGN(
+auto module,
+ParseAndReturnVerifiedModule(GetModuleStrWithCommonComputations(
+/*name=*/"test", kMoreComputationsStr),
+config));

// This pipeline consists of a total of 8 layers (2 per replica), each of
// which is a single linear layer. We assign the weights to the replicas such

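The two tests differ only in when the circular buffer is written relative to the read that feeds collective-permute. A scalar sketch of the two orderings (editor's illustration under assumed semantics; CollectivePermute and Dot are identity placeholders, and the input/weights logic is omitted):

#include <array>
#include <cstdint>

using Slice = std::array<float, 16>;
using Buf = std::array<Slice, 5>;

// Identity placeholders for the HLO ops, only here to expose the data flow.
Slice CollectivePermute(const Slice& s) { return s; }
Slice Dot(const Slice& s) { return s; }

// First test: the buffer slot is overwritten with *this* iteration's result,
// so the shifted value chains through the previous iteration's dot.
void StepWithDirectDependency(Buf& buffer, Slice& prev_res, uint32_t i) {
  Slice stage_in = CollectivePermute(buffer[(i + 3) % 5]);
  Slice res = Dot(stage_in);
  buffer[(i + 0) % 5] = res;  // update with the current compute result
  prev_res = res;
}

// Second test: the buffer is read first and then updated with the result
// carried over from the *previous* iteration, so the shift does not have to
// wait on that iteration's dot.
void StepWithoutDirectDependency(Buf& buffer, Slice& prev_res, uint32_t i) {
  Slice stage_in = CollectivePermute(buffer[(i + 3) % 5]);  // read first
  buffer[(i + 4) % 5] = prev_res;  // update with the previous compute result
  prev_res = Dot(stage_in);
}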