Skip to content

Commit

Permalink
Rename convertToNewSharding to convertToSdySharding to better ref…
Browse files Browse the repository at this point in the history
…lect its purpose.

PiperOrigin-RevId: 675622382
  • Loading branch information
ZixuanJiang authored and Google-ML-Automation committed Sep 17, 2024
1 parent 8ace4ee commit 6be9e7b
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
8 changes: 4 additions & 4 deletions xla/service/spmd/shardy/mhlo_round_trip/mhlo_import.cc
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ MeshAxesAndIds findMeshAxesAndIds(ModuleOp moduleOp) {

// Convert the `hloSharding` into a `TensorShardingAttr` based on the
// `globalMesh`.
TensorShardingAttr convertToNewSharding(
TensorShardingAttr convertToSdySharding(
const xla::HloSharding& hloSharding, MeshAttr globalMesh,
const SmallDenseMap<int64_t, StringRef>& deviceIdToMaximalMeshName,
int64_t rank, bool openDims) {
Expand Down Expand Up @@ -510,7 +510,7 @@ LogicalResult importShardings(
funcOp.getArgAttrOfType<StringAttr>(argNum, kXlaShardingAttr)) {
funcOp.setArgAttr(
argNum, kShardingAttr,
convertToNewSharding(parseShardingFromString(oldSharding), globalMesh,
convertToSdySharding(parseShardingFromString(oldSharding), globalMesh,
deviceIdToMaximalMeshName,
mlir::cast<ShapedType>(argType).getRank(),
shouldOpenDims(allowPropagationToArgs, argNum)));
Expand All @@ -523,7 +523,7 @@ LogicalResult importShardings(
funcOp.getResultAttrOfType<StringAttr>(resNum, kXlaShardingAttr)) {
funcOp.setResultAttr(
resNum, kShardingAttr,
convertToNewSharding(
convertToSdySharding(
parseShardingFromString(oldSharding), globalMesh,
deviceIdToMaximalMeshName,
mlir::cast<ShapedType>(resType).getRank(),
Expand All @@ -544,7 +544,7 @@ LogicalResult importShardings(
newShardings.reserve(op->getNumResults());
for (const auto& [resHloSharding, resType] :
llvm::zip_equal(flatHloSharding, op->getResultTypes())) {
newShardings.push_back(convertToNewSharding(
newShardings.push_back(convertToSdySharding(
resHloSharding, globalMesh, deviceIdToMaximalMeshName,
mlir::cast<ShapedType>(resType).getRank(),
/*openDims=*/false));
Expand Down
2 changes: 1 addition & 1 deletion xla/service/spmd/shardy/mhlo_round_trip/mhlo_import.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ xla::HloSharding parseShardingFromString(const mlir::StringAttr& sharding);
//
// If `hloSharding` is unknown, return fully open sharding. Otherwise, the
// returned sharding is open iff `openDims` is true.
mlir::sdy::TensorShardingAttr convertToNewSharding(
mlir::sdy::TensorShardingAttr convertToSdySharding(
const xla::HloSharding& hloSharding, mlir::sdy::MeshAttr globalMesh,
const llvm::SmallDenseMap<int64_t, mlir::StringRef>&
deviceIdToMaximalMeshName,
Expand Down

0 comments on commit 6be9e7b

Please sign in to comment.