Skip to content

Commit

Permalink
fix minor issues
Browse files Browse the repository at this point in the history
  • Loading branch information
terryysun committed Nov 19, 2024
1 parent d3657f8 commit c9202fa
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 24 deletions.
16 changes: 9 additions & 7 deletions xla/hlo/builder/xla_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4093,9 +4093,9 @@ XlaOp XlaBuilder::CollectivePermuteImpl(
return ReportErrorOrReturn([&]() -> absl::StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(Shape shape,
ShapeInference::InferCollectivePermuteShape(
{{operand_shape}}, inplace));
TF_ASSIGN_OR_RETURN(
Shape shape,
ShapeInference::InferCollectivePermuteShape(operand_shape, inplace));
*instr.mutable_shape() = shape.ToProto();

for (const auto& pair : source_target_pairs) {
Expand Down Expand Up @@ -4127,11 +4127,13 @@ XlaOp XlaBuilder::CollectivePermuteImpl(
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
operand_shapes.push_back(operand_shape);
}
CHECK(operand_shapes.size() > 1);
CHECK_GT(operand_shapes.size(), 1);
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(
Shape shape,
ShapeInference::InferCollectivePermuteShape(operand_shapes, inplace));
auto tuple_operand_shapes =
ShapeUtil::MakeTupleShapeWithPtrs(operand_shapes);
TF_ASSIGN_OR_RETURN(Shape shape,
ShapeInference::InferCollectivePermuteShape(
&tuple_operand_shapes, inplace));
*instr.mutable_shape() =
ShapeUtil::MakeTupleShapeWithPtrs(operand_shapes).ToProto();

Expand Down
2 changes: 1 addition & 1 deletion xla/python/ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,6 @@ void BuildOpsSubmodule(nb::module_& m) {
UNARY_OP(Conj);
UNARY_OP(OptimizationBarrier);
#undef UNARY_OP
}
} // NOLINT(readability/fn_size)

} // namespace xla
3 changes: 2 additions & 1 deletion xla/service/hlo_verifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -857,9 +857,10 @@ absl::Status ShapeVerifier::HandleCollectivePermute(HloInstruction* hlo) {
absl::c_transform(
hlo->operands(), std::back_inserter(operand_shapes),
[](const HloInstruction* operand) { return &(operand->shape()); });
auto tuple_operand_shapes = ShapeUtil::MakeTupleShapeWithPtrs(operand_shapes);
return CheckShape(hlo,
ShapeInference::InferCollectivePermuteShape(
operand_shapes,
&tuple_operand_shapes,
Cast<HloCollectivePermuteInstruction>(hlo)->inplace()));
}

Expand Down
24 changes: 11 additions & 13 deletions xla/service/shape_inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2723,7 +2723,16 @@ ShapeInference::InferCollectiveBroadcastShape(
}

/* static */ absl::StatusOr<Shape> ShapeInference::InferCollectivePermuteShape(
absl::Span<const Shape* const> operand_shapes, bool inplace) {
const Shape* const operand_shape, bool inplace) {
std::vector<const Shape*> operand_shapes;
if (!operand_shape->IsTuple()) {
operand_shapes = {operand_shape};
} else {
absl::c_transform(operand_shape->tuple_shapes(),
std::back_inserter(operand_shapes),
[](const Shape& shape) { return &shape; });
}
CHECK_GT(operand_shapes.size(), 0);
if (!inplace) {
for (const Shape* operand_shape : operand_shapes) {
TF_RETURN_IF_ERROR(
Expand All @@ -2739,17 +2748,6 @@ ShapeInference::InferCollectiveBroadcastShape(
}
}

/* static */ absl::StatusOr<Shape> ShapeInference::InferCollectivePermuteShape(
const Shape* const operand_shape, bool inplace) {
CHECK(operand_shape->IsTuple());
std::vector<const Shape*> operand_shapes;
absl::c_transform(operand_shape->tuple_shapes(),
std::back_inserter(operand_shapes),
[](const Shape& shape) { return &shape; });
CHECK(operand_shapes.size() > 1);
return InferCollectivePermuteShape(operand_shapes, inplace);
}

/* static */ absl::StatusOr<Shape>
ShapeInference::InferCollectivePermuteStartShape(
absl::Span<const Shape* const> operand_shapes,
Expand Down Expand Up @@ -2782,7 +2780,7 @@ ShapeInference::InferCollectivePermuteStartShape(
absl::c_transform(operand_shape->tuple_shapes(),
std::back_inserter(operand_shapes),
[](const Shape& shape) { return &shape; });
CHECK(operand_shapes.size() > 1);
CHECK_GT(operand_shapes.size(), 1);
return InferCollectivePermuteStartShape(operand_shapes, context_shapes,
inplace);
}
Expand Down
2 changes: 0 additions & 2 deletions xla/service/shape_inference.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,6 @@ class ShapeInference {
absl::Span<const Shape* const> operand_shapes);

// Infers the shape of a collective permute operation.
static absl::StatusOr<Shape> InferCollectivePermuteShape(
absl::Span<const Shape* const> operand_shapes, bool inplace = false);
static absl::StatusOr<Shape> InferCollectivePermuteShape(
const Shape* const operand_shape, bool inplace = false);

Expand Down

0 comments on commit c9202fa

Please sign in to comment.