Skip to content

Commit

Permalink
dyno: fix 'called expression' detection (chapel-lang#24723)
Browse files Browse the repository at this point in the history
Fixes Cray/chapel-private#6077.

Dyno handles resolving base expressions (e.g., the `f` in `f(x)`)
differently than plain expressions; this is because, although `a.b`
should be a field access, if we know that it's actually `a.b(c, d)`, we
need to resolve the method `b` on `a` with arguments `c` and `d` (which
happens when the call itself is resolved). Thus, resolution of `a.b` is
deferred.

However, Dyno's detection of when an expression is the "called
expression" is incorrect. In particular, it doesn't work for nested
called expressions, like `a.b(x).c(y)`. This is because the `Resolver`
only uses a single state variable, `inLeafCall`, which it sets on
entering a call, and resets to `nullptr` upon exiting a call.
Unfortunately, since the entering and exiting are different functions
called at different times, `exit` doesn't know the original / prior
value of `inLeafCall`, and as a result, resets it to `nullptr`. This
means that in the nested call case, when handling the outer call,
`inLeafCall` is incorrectly `nullptr`.

This PR fixes the issue by switching to using a stack of called
expressions. By doing so, we can preserve the original value of what
used to be `inLeafCall`, and thus correctly detect called
sub-expressions even when calls are chained / nested. While there, this
PR ensures that resetting `inLeafCall` is properly _unset_. Some
early-return logic in `exit(Call)` skips unsetting it. To handle this
while preserving the elegance of the early-return logic, I put the
`exit(Call)` logic into a helper function, and always invoke `pop_back`
after calling the helper.

Reviewed by @benharsh -- thanks!

## Testing
- [x] paratest
  • Loading branch information
DanilaFe authored Mar 29, 2024
2 parents f47881f + 0ca9c31 commit 24b96ca
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 12 deletions.
32 changes: 21 additions & 11 deletions frontend/lib/resolution/Resolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,12 @@ Resolver::paramLoopResolver(Resolver& parent,
return ret;
}

const AstNode* Resolver::nearestCalledExpression() const {
if (callNodeStack.empty()) return nullptr;

return callNodeStack.back()->calledExpression();
}

void Resolver::setCompositeType(const CompositeType* ct) {
CHPL_ASSERT(this->inCompositeType == nullptr);
this->inCompositeType = ct;
Expand Down Expand Up @@ -2552,8 +2558,7 @@ Resolver::lookupIdentifier(const Identifier* ident,
CHPL_ASSERT(scopeStack.size() > 0);
const Scope* scope = scopeStack.back();

bool resolvingCalledIdent = (inLeafCall &&
ident == inLeafCall->calledExpression());
bool resolvingCalledIdent = nearestCalledExpression() == ident;

LookupConfig config = LOOKUP_DECLS |
LOOKUP_IMPORT_AND_USE |
Expand Down Expand Up @@ -2842,8 +2847,8 @@ void Resolver::resolveIdentifier(const Identifier* ident,
// record R { type t = int; }
// var x: R; // should refer to R(int)
bool computeDefaults = true;
bool resolvingCalledIdent = (inLeafCall &&
ident == inLeafCall->calledExpression());
bool resolvingCalledIdent = nearestCalledExpression() == ident;

if (resolvingCalledIdent) {
computeDefaults = false;
}
Expand Down Expand Up @@ -2962,7 +2967,7 @@ bool Resolver::enter(const TypeQuery* tq) {

if (!foundFormalSubstitution) {
// No substitution (i.e. initial signature) so use AnyType
if (inLeafCall && isCallToIntEtc(inLeafCall)) {
if (!callNodeStack.empty() && isCallToIntEtc(callNodeStack.back())) {
auto defaultInt = IntType::get(context, 0);
result.setType(QualifiedType(QualifiedType::PARAM, defaultInt));
} else {
Expand Down Expand Up @@ -3329,7 +3334,7 @@ types::QualifiedType Resolver::typeForBooleanOp(const uast::OpCall* op) {
}

bool Resolver::enter(const Call* call) {
inLeafCall = call;
callNodeStack.push_back(call);
auto op = call->toOpCall();

if (op && initResolver) {
Expand Down Expand Up @@ -3371,9 +3376,10 @@ void Resolver::prepareCallInfoActuals(const Call* call,
/* actualAsts */ nullptr);
}

void Resolver::exit(const Call* call) {
if (scopeResolveOnly)
void Resolver::handleCallExpr(const uast::Call* call) {
if (scopeResolveOnly) {
return;
}

if (initResolver && initResolver->handleResolvingCall(call))
return;
Expand Down Expand Up @@ -3484,8 +3490,13 @@ void Resolver::exit(const Call* call) {
ResolvedExpression& r = byPostorder.byAst(call);
r.setType(QualifiedType());
}
}

void Resolver::exit(const Call* call) {
handleCallExpr(call);

inLeafCall = nullptr;
// Always remove the call from the stack to make sure it's properly set.
callNodeStack.pop_back();
}

bool Resolver::enter(const Dot* dot) {
Expand Down Expand Up @@ -3565,8 +3576,7 @@ void Resolver::exit(const Dot* dot) {

ResolvedExpression& receiver = byPostorder.byAst(dot->receiver());

bool resolvingCalledDot = (inLeafCall &&
dot == inLeafCall->calledExpression());
bool resolvingCalledDot = nearestCalledExpression() == dot;
if (resolvingCalledDot && !scopeResolveOnly) {
// We will handle it when resolving the FnCall.

Expand Down
13 changes: 12 additions & 1 deletion frontend/lib/resolution/Resolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ struct Resolver {
std::set<ID> fieldOrFormals;
std::set<ID> instantiatedFieldOrFormals;
std::set<UniqueString> namesWithErrorsEmitted;
const uast::Call* inLeafCall = nullptr;
std::vector<const uast::Call*> callNodeStack;
bool receiverScopesComputed = false;
ReceiverScopesVec savedReceiverScopes;
Resolver* parentResolver = nullptr;
Expand Down Expand Up @@ -206,6 +206,12 @@ struct Resolver {
const uast::For* loop,
ResolutionResultByPostorderID& bodyResults);

/**
During AST traversal, find the last called expression we entered.
e.g., will return 'f' if we just entered 'f()'.
*/
const chpl::uast::AstNode* nearestCalledExpression() const;

// Set the composite type of this Resolver. It is an error to call this
// method when a composite type is already set.
void setCompositeType(const types::CompositeType* ct);
Expand Down Expand Up @@ -336,6 +342,11 @@ struct Resolver {
ID moduleId,
LookupConfig failedConfig);

// after resolving the child nodes of the call as needed, perform call resolution
// if appropriate. This is a helper function because it has some complicated
// control flow, and we want to make sure to always keep callNodeStack in sync.
void handleCallExpr(const uast::Call* call);

// handle the result of one of the functions to resolve a call. Handles:
// * r.setMostSpecific
// * r.setPoiScope
Expand Down
67 changes: 67 additions & 0 deletions frontend/test/resolution/testResolve.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1342,6 +1342,72 @@ static void test22() {
}
}

static void test23() {
Context ctx;
Context* context = &ctx;
ErrorGuard guard(context);

{
std::string prog =
R"""(
record R {
var x : int;
proc foo() {
return 5;
}
}
proc helper() {
var r : R;
return r;
}
proc foo() {
return "hello";
}
// should be an int, not a string.
var x = helper().foo();
)""";

auto t = resolveTypeOfXInit(context, prog);
assert(t.type());
assert(t.type()->isIntType());
assert(t.type()->toIntType()->isDefaultWidth());
}

{
context->advanceToNextRevision(false);
std::string prog =
R"""(
record Inner {
var x : int;
proc innerFoo() {
return x;
}
}
record Outer {
var inner : Inner;
proc helper() const ref {
return inner;
}
}
var o : Outer;
var x = o.helper().innerFoo();
)""";

auto t = resolveTypeOfXInit(context, prog);
assert(t.type());
assert(t.type()->isIntType());
assert(t.type()->toIntType()->isDefaultWidth());
}
}

int main() {
test1();
test2();
Expand All @@ -1365,6 +1431,7 @@ int main() {
test20();
test21();
test22();
test23();

return 0;
}

0 comments on commit 24b96ca

Please sign in to comment.