From 042fc5fb314cf957f584b4fb81504c2a50caac9c Mon Sep 17 00:00:00 2001 From: eaplatanios Date: Mon, 29 Jul 2024 20:46:30 -0700 Subject: [PATCH] Fixed compilation on Windows. --- .../propagation/aggressive_factor_propagation.cc | 4 ++-- .../transforms/propagation/basic_factor_propagation.cc | 4 ++-- .../sdy/transforms/propagation/basic_propagation.cc | 4 ++-- .../sdy/transforms/propagation/sharding_projection.cc | 8 ++++---- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/shardy/dialect/sdy/transforms/propagation/aggressive_factor_propagation.cc b/shardy/dialect/sdy/transforms/propagation/aggressive_factor_propagation.cc index 72dd5e5..d048c1c 100644 --- a/shardy/dialect/sdy/transforms/propagation/aggressive_factor_propagation.cc +++ b/shardy/dialect/sdy/transforms/propagation/aggressive_factor_propagation.cc @@ -149,8 +149,8 @@ UpdateTensorShardings AggressiveFactorPropagation::propagateFactorShardings( ArrayRef factorSizes, MeshAttr mesh, Operation* op, bool conservativePropagation) const { UpdateTensorShardings result{ - .updateOperands = BitVector(projection.getNumOperands()), - .updateResults = BitVector(projection.getNumResults())}; + /* .updateOperands = */ BitVector(projection.getNumOperands()), + /* .updateResults = */ BitVector(projection.getNumResults())}; // We get the compatible major sharding axes for all factors. AxesPerFactor axesPerFactor = getCompatibleMajorShardingAxesForAllFactors( diff --git a/shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.cc b/shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.cc index 78954d6..91ada00 100644 --- a/shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.cc +++ b/shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.cc @@ -409,8 +409,8 @@ UpdateTensorShardings BasicFactorPropagation::propagateFactorShardings( ArrayRef factorSizes, MeshAttr mesh, Operation* op, bool conservativePropagation) const { UpdateTensorShardings result{ - .updateOperands = BitVector(projection.getNumOperands()), - .updateResults = BitVector(projection.getNumResults())}; + /* .updateOperands = */ BitVector(projection.getNumOperands()), + /* .updateResults = */ BitVector(projection.getNumResults())}; // We propagate each factor separately. for (auto [factorIndex, factorSize] : llvm::enumerate(factorSizes)) { diff --git a/shardy/dialect/sdy/transforms/propagation/basic_propagation.cc b/shardy/dialect/sdy/transforms/propagation/basic_propagation.cc index 2947c3d..1432489 100644 --- a/shardy/dialect/sdy/transforms/propagation/basic_propagation.cc +++ b/shardy/dialect/sdy/transforms/propagation/basic_propagation.cc @@ -612,8 +612,8 @@ LogicalResult BasicPropagationPassImpl::propagate( // convergence), since we make sure ops whose sharding changes are // added back to the worklist. GreedyRewriteConfig config{ - .useTopDownTraversal = true, - .enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled}; + /* .useTopDownTraversal = */ true, + /* .enableRegionSimplification = */ mlir::GreedySimplifyRegionLevel::Disabled}; if (failed(applyPatternsAndFoldGreedily(moduleOp, std::move(patterns), config))) { return failure(); diff --git a/shardy/dialect/sdy/transforms/propagation/sharding_projection.cc b/shardy/dialect/sdy/transforms/propagation/sharding_projection.cc index 111372e..d89cfdf 100644 --- a/shardy/dialect/sdy/transforms/propagation/sharding_projection.cc +++ b/shardy/dialect/sdy/transforms/propagation/sharding_projection.cc @@ -146,8 +146,8 @@ TensorShardingAttr TensorFactorShardings::createTensorShardingAttr( UpdateShardings ShardingProjection::updateSharding( int64_t factorIndex, ArrayRef newAxes) { - UpdateShardings result{.updateOperands = BitVector(getNumOperands()), - .updateResults = BitVector(getNumResults())}; + UpdateShardings result{/* .updateOperands = */ BitVector(getNumOperands()), + /* .updateResults = */ BitVector(getNumResults())}; for (auto [i, tensor] : llvm::enumerate(operands)) { result.updateOperands[i] = tensor.updateShardingAxes(factorIndex, newAxes); } @@ -180,8 +180,8 @@ std::optional getAxisRefInfo(ArrayRef axes, AxisRefAttr axisRef = axes[axisIndex]; SubAxisInfoAttr splitInfo = axisRef.getSubAxisInfo(); return AxisRefInfo{ - .size = axisRef.getSize(mesh), - .splitPreSize = splitInfo ? std::make_optional(splitInfo.getPreSize()) + /* .size = */ axisRef.getSize(mesh), + /* .splitPreSize = */ splitInfo ? std::make_optional(splitInfo.getPreSize()) : std::nullopt}; }