Skip to content

Commit

Permalink
Fix an ALA bug related to unaligned followers (chapel-lang#25753)
Browse files Browse the repository at this point in the history
Resolves chapel-lang#25727

The following code would fail because of an ALA bug:

```chpl
use BlockDist;

var A: [blockDist.createDomain({1..100})] int = 5;

var db = {1..50};
var B: [blockDist.createDomain(db)] int = 1;

forall (b, i) in zip(B, db) {
  b = A[i];
}

writeln(B);
```

The root cause was that when checking for `A[i]`s alignment, we would
check against `db` instead of `B`. What's tricky here is that `db` is
actually "aligned" with `A` as it happens to be its local subdomain.
However, that only matters if `db` was actually the leader of the loop.

This PR fixes the logic in the ALA implementation. The core change is
that our static and dynamic checks for ALA would accept (1) accessBase
(`A`) and (2) the loop domain (used to be `db`, erroneously). After this
PR, they take (1) accessBase (`A`), (2) the loop domain (`B`,
correctly), and (3) iterand of the accessBase (`db`). Static and dynamic
check logic is adjusted accordingly.

While there this PR:

- does a very mild refactor to move the common (problematic) logic
between static and dynamic checks into a helper function in the
compiler.
- renames module helpers to have `chpl__ala_` prefix for consistency
internally within the module and with other optimizations.

[Reviewed by @benharsh]

Test
- [x] linux64
- [x] gasnet
  • Loading branch information
e-kayrakli authored Aug 29, 2024
2 parents 2f22f74 + 122198a commit 80308be
Show file tree
Hide file tree
Showing 8 changed files with 111 additions and 59 deletions.
3 changes: 3 additions & 0 deletions compiler/include/ForallStmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ class ForallOptimizationInfo {
ForallAutoLocalAccessCloneType cloneType;

ForallOptimizationInfo();

Expr* getIterand(int idx);
Expr* getLoopDomainExpr();
};

///////////////////////////////////
Expand Down
55 changes: 38 additions & 17 deletions compiler/optimizations/forallOptimizations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1179,6 +1179,25 @@ void ALACandidate::addOffset(Expr* e) {
offsetExprs_.push_back(zero);
}
}
Expr* ForallOptimizationInfo::getLoopDomainExpr() {
// loop domain is the 0th iterand
return this->getIterand(0);
}

Expr* ForallOptimizationInfo::getIterand(int idx) {
if (auto &s = iterSym[idx]) {
return new SymExpr(s);
}
else if (auto &e = dotDomIterExpr[idx]) {
return e->copy();
}
else if (auto &s = iterCallTmp[idx]) {
return new SymExpr(s);
}
else {
return nullptr;
}
}

// for a call like `A[i]`, this will create something like
//
Expand All @@ -1192,8 +1211,8 @@ static void generateDynamicCheckForAccess(ALACandidate& candidate,
CallExpr *&allChecks) {
ForallOptimizationInfo &optInfo = forall->optInfo;
Symbol *baseSym = candidate.getCallBase();
int iterandIdx = candidate.getIterandIdx();
INT_ASSERT(baseSym);
const int iterandIdx = candidate.getIterandIdx();

auto& staticCheckSymMap = candidate.hasOffset() ?
optInfo.staticCheckWOffSymForSymMap :
Expand All @@ -1202,19 +1221,20 @@ static void generateDynamicCheckForAccess(ALACandidate& candidate,
SET_LINENO(forall);

if (optInfo.dynamicCheckForSymMap.count(baseSym) == 0) {
CallExpr* check = new CallExpr("chpl__dynamicAutoLocalCheck");
CallExpr* check = new CallExpr("chpl__ala_dynamicCheck");
optInfo.dynamicCheckForSymMap[baseSym] = check;

check->insertAtTail(baseSym);

if (optInfo.iterSym[iterandIdx] != NULL) {
check->insertAtTail(new SymExpr(optInfo.iterSym[iterandIdx]));
if (Expr* e = optInfo.getLoopDomainExpr()) {
check->insertAtTail(e);
}
else if (optInfo.dotDomIterExpr[iterandIdx] != NULL) {
check->insertAtTail(optInfo.dotDomIterExpr[iterandIdx]->copy());
else {
INT_FATAL("optInfo didn't have enough information");
}
else if (optInfo.iterCallTmp[iterandIdx] != NULL) {
check->insertAtTail(new SymExpr(optInfo.iterCallTmp[iterandIdx]));

if (Expr* e = optInfo.getIterand(iterandIdx)) {
check->insertAtTail(e);
}
else {
INT_FATAL("optInfo didn't have enough information");
Expand Down Expand Up @@ -1260,7 +1280,7 @@ static Symbol *generateStaticCheckForAccess(ALACandidate& candidate,

ForallOptimizationInfo &optInfo = forall->optInfo;
Symbol *baseSym = candidate.getCallBase();
const int iterandIdx = candidate.getIterandIdx();
int iterandIdx = candidate.getIterandIdx();
INT_ASSERT(baseSym);

auto& staticCheckSymMap = candidate.hasOffset() ?
Expand All @@ -1270,24 +1290,25 @@ static Symbol *generateStaticCheckForAccess(ALACandidate& candidate,
if (staticCheckSymMap.count(baseSym) == 0) {
SET_LINENO(forall);

VarSymbol *checkSym = new VarSymbol("chpl__staticAutoLocalCheckSym");
VarSymbol *checkSym = new VarSymbol("chpl__ala_staticCheckSym");
checkSym->addFlag(FLAG_PARAM);
// mark it with FLAG_TEMP to prevent the normalizer from adding
// PRIM_END_OF_STATEMENT in the wrong places for loops.
checkSym->addFlag(FLAG_TEMP);
staticCheckSymMap[baseSym] = checkSym;

CallExpr *checkCall = new CallExpr("chpl__staticAutoLocalCheck");
CallExpr *checkCall = new CallExpr("chpl__ala_staticCheck");
checkCall->insertAtTail(baseSym);

if (optInfo.iterSym[iterandIdx] != NULL) {
checkCall->insertAtTail(new SymExpr(optInfo.iterSym[iterandIdx]));
if (Expr* e = optInfo.getLoopDomainExpr()) {
checkCall->insertAtTail(e);
}
else if (optInfo.dotDomIterExpr[iterandIdx] != NULL) {
checkCall->insertAtTail(optInfo.dotDomIterExpr[iterandIdx]->copy());
else {
INT_FATAL("optInfo didn't have enough information");
}
else if (optInfo.iterCallTmp[iterandIdx] != NULL) {
checkCall->insertAtTail(new SymExpr(optInfo.iterCallTmp[iterandIdx]));

if (Expr* e = optInfo.getIterand(iterandIdx)) {
checkCall->insertAtTail(e);
}
else {
INT_FATAL("optInfo didn't have enough information");
Expand Down
76 changes: 34 additions & 42 deletions modules/internal/ChapelAutoLocalAccess.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ module ChapelAutoLocalAccess {
// note that the compiler can pass an iterator to `loopDomain` argument. Make
// sure that we don't do anything with iterators as we cannot optimize such
// forall's and we don't want to mess up the iterator
proc chpl__staticAutoLocalCheck(accessBase: [], loopDomain: domain,
param hasOffsets=false) param {
proc chpl__ala_staticCheck(accessBase: [], loopDomain: domain,
myIterand: domain, param hasOffsets=false) param {
if hasOffsets && !accessBase.domain.supportsOffsetAutoLocalAccess() {
return false;
}
Expand All @@ -41,35 +41,22 @@ module ChapelAutoLocalAccess {

// support forall i in a.domain.localSubdomain() do .... a[i] ....
if loopDomain._value.type.isDefaultRectangular() {
return accessBase.domain.supportsAutoLocalAccess();
return accessBase.domain.supportsAutoLocalAccess() &&
accessBase.rank == loopDomain.rank;
}

return false;
}

proc chpl__staticAutoLocalCheck(accessBase, loopDomain,
param hasOffsets=false) param {
return false;
}

// these type overloads are for degenerate cases where the optimization can
// break a meaningful error message without these
proc chpl__staticAutoLocalCheck(type accessBase, type loopDomain,
param hasOffsets=false) param {
return false;
}
proc chpl__staticAutoLocalCheck(accessBase, type loopDomain,
param hasOffsets=false) param {
return false;
}
proc chpl__staticAutoLocalCheck(type accessBase, loopDomain,
param hasOffsets=false) param {
return false;
proc chpl__ala_staticCheck(accessBase: [], loopDomain: [],
myIterand: domain, param hasOffsets=false) param {
return chpl__ala_staticCheck(accessBase, loopDomain.domain, myIterand,
hasOffsets);
}

proc chpl__dynamicAutoLocalCheck(accessBase, loopDomain,
param hasOffsets=false) {
if chpl__staticAutoLocalCheck(accessBase, loopDomain, hasOffsets) {
proc chpl__ala_dynamicCheck(accessBase: [], loopDomain: domain,
myIterand: domain, param hasOffsets=false) {
if chpl__ala_staticCheck(accessBase, loopDomain, myIterand, hasOffsets) {
// if they're the same domain...
if chpl_sameDomainKind(accessBase.domain, loopDomain) &&
accessBase.domain == loopDomain &&
Expand All @@ -85,9 +72,9 @@ module ChapelAutoLocalAccess {
//
// Be also aware that `subset` call below can be expensive if we are not
// calling on default rectangular
if loopDomain._value.type.isDefaultRectangular() {
if loopDomain.locale == here {
if accessBase.localSubdomain().contains(loopDomain) {
if myIterand._value.type.isDefaultRectangular() {
if myIterand.locale == here {
if accessBase.localSubdomain().contains(myIterand) {
return true;
}
}
Expand All @@ -97,21 +84,6 @@ module ChapelAutoLocalAccess {
return false;
}

// these type overloads are for degenerate cases where the optimization can
// break a meaningful error message without these
proc chpl__dynamicAutoLocalCheck(type accessBase, type loopDomain,
param hasOffsets=false) {
return false;
}
proc chpl__dynamicAutoLocalCheck(accessBase, type loopDomain,
param hasOffsets=false) {
return false;
}
proc chpl__dynamicAutoLocalCheck(type accessBase, loopDomain,
param hasOffsets=false) {
return false;
}

inline proc chpl__ala_offsetCheck(accessBase: [], offsets:integral...) {
if (offsets.size != accessBase.rank) {
compilerError("Automatic local access optimization failure: ",
Expand Down Expand Up @@ -140,4 +112,24 @@ module ChapelAutoLocalAccess {
}

}

// these type overloads are for degenerate cases where the optimization can
// break a meaningful error message without these
proc chpl__ala_staticCheck(type a, type l, type m, param h=false) param do return false;
proc chpl__ala_staticCheck(type a, type l, m, param h=false) param do return false;
proc chpl__ala_staticCheck(type a, l, type m, param h=false) param do return false;
proc chpl__ala_staticCheck(type a, l, m, param h=false) param do return false;
proc chpl__ala_staticCheck( a, type l, type m, param h=false) param do return false;
proc chpl__ala_staticCheck( a, type l, m, param h=false) param do return false;
proc chpl__ala_staticCheck( a, l, type m, param h=false) param do return false;
proc chpl__ala_staticCheck( a, l, m, param h=false) param do return false;
proc chpl__ala_dynamicCheck(type a, type l, type m, param h=false) do return false;
proc chpl__ala_dynamicCheck(type a, type l, m, param h=false) do return false;
proc chpl__ala_dynamicCheck(type a, l, type m, param h=false) do return false;
proc chpl__ala_dynamicCheck(type a, l, m, param h=false) do return false;
proc chpl__ala_dynamicCheck( a, type l, type m, param h=false) do return false;
proc chpl__ala_dynamicCheck( a, type l, m, param h=false) do return false;
proc chpl__ala_dynamicCheck( a, l, type m, param h=false) do return false;
proc chpl__ala_dynamicCheck( a, l, m, param h=false) do return false;

}
14 changes: 14 additions & 0 deletions test/optimizations/autoLocalAccess/unalignedFollower.chpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// we want dynamic checks to fail here. The way to check is to make sure that
// `localAccess` is never called.
use BlockDist;

var A: [blockDist.createDomain({1..10})] int = 5;

var db = {1..5};
var B: [blockDist.createDomain(db)] int = 1;

forall (b, i) in zip(B, db) {
b = A[i];
}

writeln(B);
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
-slogAllArrEltAccess=true
17 changes: 17 additions & 0 deletions test/optimizations/autoLocalAccess/unalignedFollower.good
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@

Start analyzing forall (unalignedFollower.chpl:10)
| Found loop domain (unalignedFollower.chpl:8)
| Will attempt static and dynamic optimizations (unalignedFollower.chpl:10)
|
| Start analyzing call (unalignedFollower.chpl:11)
| Can't determine the domain of access base (unalignedFollower.chpl:5)
| This call is a dynamic optimization candidate (unalignedFollower.chpl:11)
|
End analyzing forall (unalignedFollower.chpl:10)

Static check successful. Using localAccess with dynamic check (unalignedFollower.chpl:11)
Static check successful. Using localAccess with dynamic check (unalignedFollower.chpl:11)
5 5 5 5 5

Numbers collected by prediff:
localAccess was called 0 times
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
2
3 changes: 3 additions & 0 deletions test/optimizations/autoLocalAccess/unalignedFollower.prediff
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#!/bin/sh

./PREDIFF-filter-accessors $1 $2 --no-this

0 comments on commit 80308be

Please sign in to comment.