Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tests for checking integral return types #26603

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 21 additions & 16 deletions compiler/resolution/resolveFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1965,12 +1965,12 @@ void resolveIfExprType(CondStmt* stmt) {
retType = thenType;
} else {
bool promote = false;

if (canDispatch(elseType, elseSym, thenType, NULL, fn, &promote) &&
promote == false) {
bool paramNarrows = false;
if (canDispatch(elseType, elseSym, thenType, NULL, fn, &promote, &paramNarrows) &&
promote == false && !paramNarrows) {
retType = thenType;
} else if (canDispatch(thenType, thenSym, elseType, NULL, fn, &promote) &&
promote == false) {
} else if (canDispatch(thenType, thenSym, elseType, NULL, fn, &promote, &paramNarrows) &&
promote == false && !paramNarrows) {
retType = elseType;
}
}
Expand Down Expand Up @@ -2102,19 +2102,24 @@ void resolveReturnTypeAndYieldedType(FnSymbol* fn, Type** yieldedType) {
for (int j = 0; j < retTypes.n; j++) {
if (retTypes.v[i] != retTypes.v[j]) {
bool requireScalarPromotion = false;
bool paramNarrows = false;
if (canDispatch(retTypes.v[j],
retSymbols.v[j],
retTypes.v[i],
NULL,
fn,
&requireScalarPromotion,
&paramNarrows) == false) {
best = false;
}

if (canDispatch(retTypes.v[j],
retSymbols.v[j],
retTypes.v[i],
NULL,
fn,
&requireScalarPromotion) == false) {
best = false;
}
if (requireScalarPromotion) {
best = false;
}

if (requireScalarPromotion) {
best = false;
}
if (paramNarrows && fn->retTag != RET_PARAM) {
best = false;
}
}
}

Expand Down
1 change: 1 addition & 0 deletions frontend/test/resolution/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ comp_unit_test(testGetSymbolsAvailableInScope)
comp_unit_test(testHeapBuffer)
comp_unit_test(testIf)
comp_unit_test(testInitSemantics)
comp_unit_test(testIntegralReturnType)
comp_unit_test(testInteractive)
comp_unit_test(testInterfaces)
comp_unit_test(testIterators)
Expand Down
230 changes: 230 additions & 0 deletions frontend/test/resolution/testIntegralReturnType.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
/*
* Copyright 2021-2025 Hewlett Packard Enterprise Development LP
* Other additional copyright holders may be indicated within.
*
* The entirety of this work is licensed under the Apache License,
* Version 2.0 (the "License"); you may not use this file except
* in compliance with the License.
*
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "test-resolution.h"

#include "chpl/parsing/parsing-queries.h"
#include "chpl/resolution/resolution-queries.h"
#include "chpl/resolution/scope-queries.h"
#include "chpl/types/all-types.h"
#include "chpl/uast/Identifier.h"
#include "chpl/uast/Module.h"
#include "chpl/uast/Record.h"
#include "chpl/uast/Variable.h"

static std::string buildControlFlowProgram(std::string existingProgram, std::string boolVal, std::string typeVal1, std::string typeVal2) {
std::string program = existingProgram;
program += "\nvar x = f(";
program += boolVal;
program += ", ";
program += typeVal1;
program += ", ";
program += typeVal2;
program += ");\n";
return program;
}

// need param and non-param variants
static std::string buildIfExpressionParam(std::string boolVal, std::string typeVal1, std::string typeVal2) {
std::string program = "";
program += "var x = if ";
program += boolVal;
program += " then 0:";
program += typeVal1;
program += " else 0:";
program += typeVal2;
program += ";\n";
return program;
}

static std::string buildIfExpressionNoParam(std::string boolVal, std::string typeVal1, std::string typeVal2) {
std::string program = "";
program += "var zero:uint(8);\n";
program += "var x = if ";
program += boolVal;
program += " then zero:";
program += typeVal1;
program += " else zero:";
program += typeVal2;
program += ";\n";
return program;
}

enum class ProgramTestType {
TestGeneric,
TestGenericCast,
TestParamGenericCast,
TestExpression,
TestParamExpression
};

static void testIntegralReturn(std::string baseProgram, std::string boolVal,
std::string typeVal1, std::string typeVal2,
const chpl::types::PrimitiveType* expectedType,
ProgramTestType testType) {
auto context = buildStdContext();
ErrorGuard guard(context);
std::string program;
std::string lineOut = "";
if (testType == ProgramTestType::TestGeneric) {
program = buildControlFlowProgram(baseProgram, boolVal, typeVal1, typeVal2);
lineOut += "testGeneric: when ";
} else if (testType == ProgramTestType::TestGenericCast) {
program = buildControlFlowProgram(baseProgram, boolVal, typeVal1, typeVal2);
lineOut += "testGenericCast: when ";
} else if (testType == ProgramTestType::TestParamGenericCast) {
program = buildControlFlowProgram(baseProgram, boolVal, typeVal1, typeVal2);
lineOut += "testParamGenericCast: when ";
} else if (testType == ProgramTestType::TestExpression) {
program = buildIfExpressionNoParam(boolVal, typeVal1, typeVal2);
lineOut += "testExpression: when ";
} else if (testType == ProgramTestType::TestParamExpression) {
program = buildIfExpressionParam(boolVal, typeVal1, typeVal2);
lineOut += "testParamExpression: when ";
}
lineOut += boolVal;
lineOut += " ";
lineOut += typeVal1;
lineOut += " ";
lineOut += typeVal2;
lineOut += " then ";
auto returnType = resolveTypeOfXInit(context, program, true);
assert(returnType.type());
std::cout << lineOut;
returnType.type()->dump();
// re-enable this when we have the correct types for each call
// assert(returnType.type() == expectedType);
}

static void
testIfExpressionIntegralTypesHelper(std::string program,
std::string boolCondition,
ProgramTestType testType) {
auto context = buildStdContext();
ErrorGuard guard(context);
// use only ints. 8, 16, 32, 64 bit. use each size with every other size.
testIntegralReturn(program, boolCondition, "int(8)", "int(8)", IntType::get(context, 8), testType);
testIntegralReturn(program, boolCondition, "int(8)", "int(16)", IntType::get(context, 16), testType);
testIntegralReturn(program, boolCondition, "int(8)", "int(32)", IntType::get(context, 32), testType);
testIntegralReturn(program, boolCondition, "int(8)", "int(64)", IntType::get(context, 64), testType);
testIntegralReturn(program, boolCondition, "int(16)", "int(8)", IntType::get(context, 16), testType);
testIntegralReturn(program, boolCondition, "int(16)", "int(16)", IntType::get(context, 16), testType);
testIntegralReturn(program, boolCondition, "int(16)", "int(32)", IntType::get(context, 32), testType);
testIntegralReturn(program, boolCondition, "int(16)", "int(64)", IntType::get(context, 64), testType);
testIntegralReturn(program, boolCondition, "int(32)", "int(8)", IntType::get(context, 32), testType);
testIntegralReturn(program, boolCondition, "int(32)", "int(16)", IntType::get(context, 32), testType);
testIntegralReturn(program, boolCondition, "int(32)", "int(32)", IntType::get(context, 32), testType);
testIntegralReturn(program, boolCondition, "int(32)", "int(64)", IntType::get(context, 64), testType);
testIntegralReturn(program, boolCondition, "int(64)", "int(8)", IntType::get(context, 64), testType);
testIntegralReturn(program, boolCondition, "int(64)", "int(16)", IntType::get(context, 64), testType);
testIntegralReturn(program, boolCondition, "int(64)", "int(32)", IntType::get(context, 64), testType);
testIntegralReturn(program, boolCondition, "int(64)", "int(64)", IntType::get(context, 64), testType);
// do the uint versions of all the above, true and false versions
testIntegralReturn(program, boolCondition, "uint(8)", "uint(8)", UintType::get(context, 8), testType);
testIntegralReturn(program, boolCondition, "uint(8)", "uint(16)", UintType::get(context, 16), testType);
testIntegralReturn(program, boolCondition, "uint(8)", "uint(32)", UintType::get(context, 32), testType);
testIntegralReturn(program, boolCondition, "uint(8)", "uint(64)", UintType::get(context, 64), testType);
testIntegralReturn(program, boolCondition, "uint(16)", "uint(8)", UintType::get(context, 16), testType);
testIntegralReturn(program, boolCondition, "uint(16)", "uint(16)", UintType::get(context, 16), testType);
testIntegralReturn(program, boolCondition, "uint(16)", "uint(32)", UintType::get(context, 32), testType);
testIntegralReturn(program, boolCondition, "uint(16)", "uint(64)", UintType::get(context, 64), testType);
testIntegralReturn(program, boolCondition, "uint(32)", "uint(8)", UintType::get(context, 32), testType);
testIntegralReturn(program, boolCondition, "uint(32)", "uint(16)", UintType::get(context, 32), testType);
testIntegralReturn(program, boolCondition, "uint(32)", "uint(32)", UintType::get(context, 32), testType);;
testIntegralReturn(program, boolCondition, "uint(32)", "uint(64)", UintType::get(context, 64), testType);
testIntegralReturn(program, boolCondition, "uint(64)", "uint(8)", UintType::get(context, 64), testType);;
testIntegralReturn(program, boolCondition, "uint(64)", "uint(16)", UintType::get(context, 64), testType);
testIntegralReturn(program, boolCondition, "uint(64)", "uint(32)", UintType::get(context, 64), testType);
testIntegralReturn(program, boolCondition, "uint(64)", "uint(64)", UintType::get(context, 64), testType);
// tests where the first type is int and the second is uint, do all sizes of ints and uints
testIntegralReturn(program, boolCondition, "int(8)", "uint(8)", UintType::get(context, 8), testType);
testIntegralReturn(program, boolCondition, "int(8)", "uint(16)", UintType::get(context, 16), testType);
testIntegralReturn(program, boolCondition, "int(8)", "uint(32)", UintType::get(context, 32), testType);
testIntegralReturn(program, boolCondition, "int(8)", "uint(64)", UintType::get(context, 64), testType);
testIntegralReturn(program, boolCondition, "int(16)", "uint(8)", IntType::get(context, 16), testType);
testIntegralReturn(program, boolCondition, "int(16)", "uint(16)", UintType::get(context, 16), testType);
testIntegralReturn(program, boolCondition, "int(16)", "uint(32)", UintType::get(context, 32), testType);
testIntegralReturn(program, boolCondition, "int(16)", "uint(64)", UintType::get(context, 64), testType);
testIntegralReturn(program, boolCondition, "int(32)", "uint(8)", IntType::get(context, 32), testType);
testIntegralReturn(program, boolCondition, "int(32)", "uint(16)", IntType::get(context, 32), testType);
testIntegralReturn(program, boolCondition, "int(32)", "uint(32)", UintType::get(context, 32), testType);
testIntegralReturn(program, boolCondition, "int(32)", "uint(64)", UintType::get(context, 64), testType);
testIntegralReturn(program, boolCondition, "int(64)", "uint(8)", IntType::get(context, 64), testType);
testIntegralReturn(program, boolCondition, "int(64)", "uint(16)", IntType::get(context, 64), testType);
testIntegralReturn(program, boolCondition, "int(64)", "uint(32)", IntType::get(context, 64), testType);
testIntegralReturn(program, boolCondition, "int(64)", "uint(64)", UintType::get(context, 64), testType);
// reverse the order of the types so that the first type is uint and the second is int
testIntegralReturn(program, boolCondition, "uint(8)", "int(8)", UintType::get(context, 8), testType);
testIntegralReturn(program, boolCondition, "uint(8)", "int(16)", IntType::get(context, 16), testType);
testIntegralReturn(program, boolCondition, "uint(8)", "int(32)", IntType::get(context, 32), testType);
testIntegralReturn(program, boolCondition, "uint(8)", "int(64)", IntType::get(context, 64), testType);
testIntegralReturn(program, boolCondition, "uint(16)", "int(8)", UintType::get(context, 16), testType);
testIntegralReturn(program, boolCondition, "uint(16)", "int(16)", UintType::get(context, 16), testType);
testIntegralReturn(program, boolCondition, "uint(16)", "int(32)", IntType::get(context, 32), testType);
testIntegralReturn(program, boolCondition, "uint(16)", "int(64)", IntType::get(context, 64), testType);
testIntegralReturn(program, boolCondition, "uint(32)", "int(8)", UintType::get(context, 32), testType);
testIntegralReturn(program, boolCondition, "uint(32)", "int(16)", UintType::get(context, 32), testType);
testIntegralReturn(program, boolCondition, "uint(32)", "int(32)", UintType::get(context, 32), testType);
testIntegralReturn(program, boolCondition, "uint(32)", "int(64)", IntType::get(context, 64), testType);
testIntegralReturn(program, boolCondition, "uint(64)", "int(8)", UintType::get(context, 64), testType);
testIntegralReturn(program, boolCondition, "uint(64)", "int(16)", UintType::get(context, 64), testType);
testIntegralReturn(program, boolCondition, "uint(64)", "int(32)", UintType::get(context, 64), testType);
testIntegralReturn(program, boolCondition, "uint(64)", "int(64)", UintType::get(context, 64), testType);
}

static void testConditionalIntegralTypes() {
std::string testGeneric =
R"""(
proc f(arg: bool, type t, type tt) {
var i: t = 0;
var u: tt = 0;
if arg then return i; else return u;
}
)""";
std::string testGenericCast =
R"""(
proc f(arg: bool, type t, type tt) {
var zero: uint(8);
if arg then return zero:t; else return zero:tt;
}
)""";
std::string testParamGenericCast =
R"""(
proc f(arg: bool, type t, type tt) {
if arg then return 0:t; else return 0:tt;
}
)""";

testIfExpressionIntegralTypesHelper(testGeneric, "true", ProgramTestType::TestGeneric);
testIfExpressionIntegralTypesHelper(testGeneric, "false", ProgramTestType::TestGeneric);
testIfExpressionIntegralTypesHelper(testParamGenericCast, "true", ProgramTestType::TestParamGenericCast);
testIfExpressionIntegralTypesHelper(testParamGenericCast, "false", ProgramTestType::TestParamGenericCast);
testIfExpressionIntegralTypesHelper(testGenericCast, "true", ProgramTestType::TestGenericCast);
testIfExpressionIntegralTypesHelper(testGenericCast, "false", ProgramTestType::TestGenericCast);
testIfExpressionIntegralTypesHelper("", "true", ProgramTestType::TestExpression);
testIfExpressionIntegralTypesHelper("", "false", ProgramTestType::TestExpression);
testIfExpressionIntegralTypesHelper("", "true", ProgramTestType::TestParamExpression);
testIfExpressionIntegralTypesHelper("", "false", ProgramTestType::TestParamExpression);
}

int main() {
testConditionalIntegralTypes();
return 0;
}

94 changes: 94 additions & 0 deletions test/statements/conditionals/integralReturnTypesInConditional.chpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
proc testGeneric(arg: bool, type t, type tt) {
var i: t = 0;
var u: tt = 0;
if arg then return i;
else return u;
}

proc testGenericCast(arg: bool, type t, type tt) {
var zero: uint(8);
if arg then return zero:t;
else return zero:tt;
}

proc testParamGenericCast(arg: bool, type t, type tt) {
if arg then return 0:t;
else return 0:tt;
}

proc testExpression(arg: bool, type t, type tt) {
var zero: uint(8);
var x = if arg then zero:t else zero:tt;
writeln("testExpression: when ", arg:string, " ", t:string, " ",
tt:string, " then ", x.type:string);
}

proc testParamExpression(arg: bool, type t, type tt) {
var x = if arg then 0:t else 0:tt;
writeln("testParamExpression: when ", arg:string, " ", t:string, " ",
tt:string, " then ", x.type:string);
}

const A = [true, false];

const Ti = (0:int(8), 0:int(16), 0:int(32), 0:int(64));
const Tu = (0:uint(8), 0:uint(16), 0:uint(32), 0:uint(64));

for a in A {
for ti in Ti {
for tti in Ti {
var xxx = testGeneric(a, ti.type, tti.type);
writeln("testGeneric: when ", a:string, " ", ti.type:string, " ",
tti.type:string, " then ", xxx.type:string);
var yyy = testGenericCast(a, ti.type, tti.type);
writeln("testGenericCast: when ", a:string, " ", ti.type:string, " ",
tti.type:string, " then ", yyy.type:string);
testExpression(a, ti.type, tti.type);
testParamExpression(a, ti.type, tti.type);
var zzz = testParamGenericCast(a, ti.type, tti.type);
writeln("testParamGenericCast: when ", a:string, " ", ti.type:string, " ",
tti.type:string, " then ", zzz.type:string);
}
for tu in Tu {
testExpression(a, ti.type, tu.type);
testExpression(a, tu.type, ti.type);
testParamExpression(a, ti.type, tu.type);
testParamExpression(a, tu.type, ti.type);
var x = testGeneric(a, ti.type, tu.type);
writeln("testGeneric: when ", a:string, " ", ti.type:string, " ",
tu.type:string, " then ", x.type:string);
var xx = testGeneric(a, tu.type, ti.type);
writeln("testGeneric: when ", a:string, " ", tu.type:string, " ",
ti.type:string, " then " , xx.type:string);

var y = testGenericCast(a, ti.type, tu.type);
writeln("testGenericCast: when ", a:string, " ", ti.type:string, " ",
tu.type:string, " then " , y.type:string);
var yy = testGenericCast(a, tu.type, ti.type);
writeln("testGenericCast: when ", a:string, " ", tu.type:string, " ",
ti.type:string, " then " , yy.type:string);

var z = testParamGenericCast(a, ti.type, tu.type);
writeln("testParamGenericCast: when ", a:string, " ", ti.type:string, " ",
tu.type:string, " then ", z.type:string);
var zz = testParamGenericCast(a, tu.type, ti.type);
writeln("testParamGenericCast: when ", a:string, " ", tu.type:string, " ",
ti.type:string, " then ", zz.type:string);
}
}
for tz in Tu {
for tzz in Tu {
var xxx = testGeneric(a, tz.type, tzz.type);
writeln("testGeneric: when ", a:string, " ", tz.type:string, " ",
tzz.type:string, " then " , xxx.type:string);
var yyy = testGenericCast(a, tz.type, tzz.type);
writeln("testGenericCast: when ", a:string, " ", tz.type:string, " ",
tzz.type:string, " then " , yyy.type:string);
testExpression(a, tz.type, tzz.type);
testParamExpression(a, tz.type, tzz.type);
var zzz = testParamGenericCast(a, tz.type, tzz.type);
writeln("testParamGenericCast: when ", a:string, " ", tz.type:string, " ",
tzz.type:string, " then ", zzz.type:string);
}
}
}
Loading
Loading