
Commit 28e858d
test
PiperOrigin-RevId: 657736907
tomnatan30 authored and copybara-github committed Jul 30, 2024
1 parent 526fb4d commit 28e858d
Showing 2 changed files with 7 additions and 7 deletions.
@@ -388,7 +388,7 @@ UpdateTensorShardings BasicFactorPropagation::propagateFactorShardings(
     ShardingProjection& projection, PropagationDirection direction,
     ArrayRef<int64_t> factorSizes, MeshAttr mesh, Operation* op,
     bool conservativePropagation) const {
-  UpdateTensorShardings result{
+  UpdateTensorShardings result_2{
       .updateOperands = BitVector(projection.getNumOperands()),
       .updateResults = BitVector(projection.getNumResults())};
 
@@ -405,11 +405,11 @@ UpdateTensorShardings BasicFactorPropagation::propagateFactorShardings(
     auto [updateOperandForFactor, updateResultForFactor] =
         projection.updateSharding(factorIndex, axesToPropagate);
 
-    result.updateOperands |= updateOperandForFactor;
-    result.updateResults |= updateResultForFactor;
+    result_2.updateOperands |= updateOperandForFactor;
+    result_2.updateResults |= updateResultForFactor;
   }
 
-  return result;
+  return result_2;
 }
 
 }  // namespace sdy
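For readers skimming the rename above, the logic being touched is an OR-accumulation of per-factor update masks into one UpdateTensorShardings value. Below is a minimal, self-contained C++ sketch of that pattern; it compiles on its own by substituting std::vector<bool> for llvm::BitVector and a stub for projection.updateSharding, so every name in it is a hypothetical stand-in rather than the actual Shardy API.

// Sketch only: std::vector<bool> stands in for llvm::BitVector, and
// updateShardingStub stands in for projection.updateSharding(...).
#include <cstddef>
#include <iostream>
#include <utility>
#include <vector>

using Mask = std::vector<bool>;

struct UpdateTensorShardings {
  Mask updateOperands;
  Mask updateResults;
};

// Stub: pretend factor `factorIndex` updated one operand and one result.
std::pair<Mask, Mask> updateShardingStub(std::size_t factorIndex,
                                         std::size_t numOperands,
                                         std::size_t numResults) {
  Mask operands(numOperands, false), results(numResults, false);
  operands[factorIndex % numOperands] = true;
  results[factorIndex % numResults] = true;
  return {operands, results};
}

// Element-wise OR, mirroring `result.updateOperands |= ...` on BitVector.
void orInto(Mask& acc, const Mask& update) {
  for (std::size_t i = 0; i < acc.size(); ++i) acc[i] = acc[i] || update[i];
}

int main() {
  const std::size_t numOperands = 3, numResults = 2, numFactors = 4;
  UpdateTensorShardings result{Mask(numOperands, false),
                               Mask(numResults, false)};
  for (std::size_t factorIndex = 0; factorIndex < numFactors; ++factorIndex) {
    auto [updateOperandForFactor, updateResultForFactor] =
        updateShardingStub(factorIndex, numOperands, numResults);
    orInto(result.updateOperands, updateOperandForFactor);
    orInto(result.updateResults, updateResultForFactor);
  }
  for (bool b : result.updateOperands) std::cout << b;  // prints 111
  std::cout << '\n';
}

The hunks that follow come from the second of the two changed files.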
@@ -244,12 +244,12 @@ LogicalResult propagateTensorShardings(
   MeshAttr mesh = getMeshAttr(op, meshName);
   assert(mesh && "unknown mesh");
 
-  ShardingProjection shardingProjection = ShardingProjection::build(
+  ShardingProjection sharding_projection = ShardingProjection::build(
       operandShardings, resultsShardings, shardingRule, mesh);
 
   auto [updateOperand, updateResult] =
       factorPropagation.propagateFactorShardings(
-          shardingProjection, direction, shardingRule.getFactorSizes(), mesh,
+          sharding_projection, direction, shardingRule.getFactorSizes(), mesh,
           op, conservativePropagation);
 
   // We need to update the tensor sharding attributes explicitly, as we have
@@ -268,7 +268,7 @@ LogicalResult propagateTensorShardings(
   }
   updateTensorShardings(operands, results, operandShardings, resultsShardings,
                         setOperandShardingCallback, setResultShardingCallback,
-                        shardingRule, shardingProjection, updateOperand,
+                        shardingRule, sharding_projection, updateOperand,
                         updateResult, meshName, mesh, notifyOpModified);
 
   bool anyUpdated = updateOperand.any() || updateResult.any();
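This second file shows the caller side of the same API: build a projection, run factor propagation, then write back only the shardings that were flagged as changed. Here is a compiles-on-its-own sketch of that flow; every type and function in it is a hypothetical stand-in for the Shardy originals, not the real API.

// Sketch only: all names are simplified stand-ins, not the Shardy API.
#include <cstddef>
#include <iostream>
#include <vector>

struct ShardingProjection {};  // stand-in: per-factor view of the shardings

struct UpdateTensorShardings {
  std::vector<bool> updateOperands, updateResults;
};

bool any(const std::vector<bool>& bits) {
  for (bool b : bits)
    if (b) return true;
  return false;
}

// Stand-in for factorPropagation.propagateFactorShardings(...): returns
// masks flagging which operand/result shardings changed.
UpdateTensorShardings propagateFactorShardings(ShardingProjection&,
                                               std::size_t numOperands,
                                               std::size_t numResults) {
  return {std::vector<bool>(numOperands, true),
          std::vector<bool>(numResults, false)};
}

int main() {
  // Corresponds to ShardingProjection::build(...).
  ShardingProjection sharding_projection;
  auto [updateOperand, updateResult] =
      propagateFactorShardings(sharding_projection, /*numOperands=*/2,
                               /*numResults=*/1);
  // updateTensorShardings(...) would rewrite only the attributes flagged in
  // these masks; here we just report whether anything changed.
  bool anyUpdated = any(updateOperand) || any(updateResult);
  std::cout << (anyUpdated ? "op modified\n" : "no change\n");
}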
