Skip to content

Commit 2781a7e

Browse files
committed
Fix poorly defined behavior when choosing certain function overloads.
Previously the order of overload declarations would affect which was chosen.
1 parent cb9935c commit 2781a7e

File tree

12 files changed

+477
-92
lines changed

12 files changed

+477
-92
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,7 @@ TESTSUITE ( and-or-not-synonyms aastep arithmetic array array-derivs array-range
248248
error-dupes exit exponential
249249
fprintf
250250
function-earlyreturn function-simple function-outputelem
251+
function-overloads
251252
geomath getattribute-camera getattribute-shader
252253
getsymbol-nonheap gettextureinfo
253254
group-outputs groupstring

src/liboslcomp/ast.h

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -889,17 +889,6 @@ class ASTfunction_call : public ASTNode
889889
}
890890

891891
private:
892-
/// Typecheck all polymorphic versions, return UNKNOWN if no match was
893-
/// found, or a real type if there was a match. Also, upon matching,
894-
/// re-jigger m_sym to point to the specific polymorphic match.
895-
/// Allow arguments to be coerced (e.g., substituting a vector where
896-
/// a point was expected, or a float where a color was expected) only
897-
/// if coerceargs is true. For return values, allow spatial triples to
898-
/// mutually match if 'equivreturn' is true, and allow any coercive
899-
/// return type if 'expected' is TypeSpec() (i.e., unknown).
900-
TypeSpec typecheck_all_poly (TypeSpec expected, bool coerceargs,
901-
bool equivreturn);
902-
903892
/// Handle all the special cases for built-ins. This includes
904893
/// irregular patterns of which args are read vs written, special
905894
/// checks for printf- and texture-like, etc.

src/liboslcomp/typecheck.cpp

Lines changed: 228 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -903,28 +903,229 @@ ASTNode::check_arglist (const char *funcname, ASTNode::ref arg,
903903
}
904904

905905

906+
class CandidateFunctions {
907+
enum {
908+
kExactMatch = 100,
909+
kIntegralToFP = 80,
910+
kArrayMatch = 40,
911+
kCoercable = 20,
912+
kMatchAnything = 1,
913+
kNoMatch = 0,
914+
915+
// Additional rules that don't match C++ behaviour
916+
kFPToIntegral = 60, // = kIntegralToFP to match c++
917+
kMatchReturn = kExactMatch, // = 0 to match c++
918+
kCoercehReturn = kCoercable, // = 0 to match c++
919+
};
920+
struct Candidate {
921+
FunctionSymbol* sym;
922+
TypeSpec rtype;
923+
int ascore;
924+
int rscore;
925+
926+
Candidate(FunctionSymbol *s, TypeSpec rt, int as, int rs) :
927+
sym(s), rtype(rt), ascore(as), rscore(rs) {}
928+
929+
string_view name() const { return sym->name(); }
930+
};
931+
typedef std::vector<Candidate> Candidates;
932+
933+
OSLCompilerImpl* m_compiler;
934+
Candidates m_candidates;
935+
TypeSpec m_rval;
936+
ASTNode::ref m_args;
937+
size_t m_nargs;
938+
939+
const char* scoreWildcard(int& argscore, size_t& fargs, const char* args) const {
940+
while (fargs < m_nargs) {
941+
argscore += kMatchAnything;
942+
++fargs;
943+
}
944+
return args + 1;
945+
}
906946

907-
TypeSpec
908-
ASTfunction_call::typecheck_all_poly (TypeSpec expected, bool coerceargs,
909-
bool equivreturn)
910-
{
911-
for (FunctionSymbol *poly = func(); poly; poly = poly->nextpoly()) {
912-
const char *code = poly->argcodes().c_str();
947+
int addCandidate(FunctionSymbol* func) {
913948
int advance;
914-
TypeSpec returntype = m_compiler->type_from_code (code, &advance);
915-
code += advance;
916-
if (check_arglist (m_name.c_str(), args(), code, coerceargs)) {
917-
// Return types also must match if not coercible
918-
if (expected == returntype ||
919-
(equivreturn && equivalent(expected,returntype)) ||
920-
expected == TypeSpec()) {
921-
m_sym = poly;
922-
return returntype;
949+
const char *formals = func->argcodes().c_str();
950+
TypeSpec rtype = m_compiler->type_from_code (formals, &advance);
951+
formals += advance;
952+
953+
int argscore = 0;
954+
size_t fargs = 0;
955+
for (ASTNode::ref arg = m_args; *formals && arg; ++fargs, arg = arg->next()) {
956+
switch (*formals) {
957+
case '*': // Will match anything left
958+
formals = scoreWildcard(argscore, fargs, formals);
959+
ASSERT (*formals == 0);
960+
continue;
961+
962+
case '.': // Token/value pairs
963+
if (arg->typespec().is_string() && arg->next()) {
964+
formals = scoreWildcard(argscore, fargs, formals);
965+
ASSERT (*formals == 0);
966+
continue;
967+
}
968+
return kNoMatch;
969+
970+
case '?':
971+
if (formals[1] == '[' && formals[2] == ']') {
972+
// Any array
973+
formals += 3;
974+
if (!arg->typespec().is_array())
975+
return kNoMatch; // wanted an array, didn't get one
976+
argscore += kMatchAnything;
977+
} else if (!arg->typespec().is_array()) {
978+
formals += 1; // match anything
979+
argscore += kMatchAnything;
980+
} else
981+
return kNoMatch; // wanted any scalar, got an array
982+
continue;
983+
984+
default:
985+
break;
923986
}
987+
// To many arguments for the function, done without a match.
988+
if (fargs >= m_nargs)
989+
return kNoMatch;
990+
991+
TypeSpec argtype = arg->typespec();
992+
TypeSpec formaltype = m_compiler->type_from_code (formals, &advance);
993+
formals += advance;
994+
995+
if (argtype == formaltype)
996+
argscore += kExactMatch; // ok, move on to next arg
997+
else if (!argtype.is_closure() && argtype.is_scalarnum() &&
998+
!formaltype.is_closure() && formaltype.is_scalarnum())
999+
argscore += formaltype.is_int() ? kFPToIntegral : kIntegralToFP;
1000+
else if (formaltype.is_unsized_array() && argtype.is_sized_array() &&
1001+
formaltype.elementtype() == argtype.elementtype()) {
1002+
// Allow a fixed-length array match to a formal array with
1003+
// unspecified length, if the element types are the same.
1004+
argscore += kArrayMatch;
1005+
} else if (assignable (formaltype, argtype))
1006+
argscore += kCoercable;
1007+
else
1008+
return kNoMatch;
9241009
}
1010+
1011+
// Check any remaining arguments
1012+
switch (*formals) {
1013+
case '*':
1014+
case '.':
1015+
// Skip over the unused optional args
1016+
++formals;
1017+
++fargs;
1018+
case '\0':
1019+
if (fargs < m_nargs)
1020+
return 0;
1021+
break;
1022+
1023+
default:
1024+
// TODO: Scoring default function arguments would go here
1025+
// Curently an unused formal argument, so no match at all.
1026+
return 0;
1027+
}
1028+
ASSERT (*formals == 0);
1029+
1030+
int highscore = m_candidates.empty() ? 0 : m_candidates.front().ascore;
1031+
if (argscore < highscore)
1032+
return 0;
1033+
1034+
1035+
if (argscore == highscore) {
1036+
// Check for duplicate declarations
1037+
for (auto& candidate : m_candidates) {
1038+
if (candidate.sym->argcodes() == func->argcodes())
1039+
return 0;
1040+
}
1041+
} else // clear any prior ambiguous matches
1042+
m_candidates.clear();
1043+
1044+
// append the latest high scoring function
1045+
m_candidates.emplace_back(func, rtype, argscore, rtype == m_rval ?
1046+
kMatchReturn : (equivalent(rtype, m_rval) ? kCoercehReturn : kNoMatch));
1047+
1048+
return argscore;
9251049
}
926-
return TypeSpec();
927-
}
1050+
1051+
public:
1052+
CandidateFunctions(OSLCompilerImpl* compiler, TypeSpec rval, ASTNode::ref args, FunctionSymbol* func) :
1053+
m_compiler(compiler), m_rval(rval), m_args(args), m_nargs(0) {
1054+
1055+
//std::cerr << "Matching " << func->name() << " formals='" << (rval.simpletype().basetype != TypeDesc::UNKNOWN ? compiler->code_from_type (rval) : " ");
1056+
for (ASTNode::ref arg = m_args; arg; arg = arg->next()) {
1057+
//std::cerr << compiler->code_from_type (arg->typespec());
1058+
++m_nargs;
1059+
}
1060+
//std::cerr << "'\n";
1061+
1062+
while (func) {
1063+
//int score =
1064+
addCandidate(func);
1065+
//std::cerr << '\t' << func->name() << " formals='" << func->argcodes().c_str() << "' " << score << ", " << (score ? m_candidates.back().rscore : 0) << "\n";
1066+
func = func->nextpoly();
1067+
}
1068+
}
1069+
1070+
void reportError(ASTfunction_call* caller, string_view name) {
1071+
std::string actualargs;
1072+
for (ASTNode::ref arg = m_args; arg; arg = arg->next()) {
1073+
if (actualargs.length())
1074+
actualargs += ", ";
1075+
actualargs += arg->typespec().string();
1076+
}
1077+
caller->error ("No matching function call to '%s (%s)'",
1078+
name.c_str(), actualargs.c_str());
1079+
}
1080+
1081+
void reportAmbiguity(FunctionSymbol* sym) const {
1082+
int advance;
1083+
const char *formals = sym->argcodes().c_str();
1084+
TypeSpec returntype = m_compiler->type_from_code (formals, &advance);
1085+
formals += advance;
1086+
1087+
auto& errh = m_compiler->errhandler();
1088+
if (ASTNode* decl = sym->node())
1089+
errh.message("%s:%d ", decl->sourcefile(), decl->sourceline());
1090+
1091+
errh.message("candidate function:\n");
1092+
errh.message("\t%s %s (%s)\n",
1093+
m_compiler->type_c_str(returntype), sym->name(),
1094+
m_compiler->typelist_from_code(formals).c_str());
1095+
}
1096+
1097+
std::pair<FunctionSymbol*, TypeSpec> best(ASTNode* caller, bool strict = 0) {
1098+
switch (m_candidates.size()) {
1099+
case 0: return { nullptr, TypeSpec() };
1100+
case 1: return { m_candidates[0].sym, m_candidates[0].rtype };
1101+
default: break;
1102+
}
1103+
1104+
int ambiguity = 0;
1105+
std::pair<const Candidate*, int> c = { nullptr, -1 };
1106+
for (auto& candidate : m_candidates) {
1107+
// re-score based on matching return value
1108+
if (candidate.rscore > c.second)
1109+
c = std::make_pair(&candidate, candidate.rscore);
1110+
else if (candidate.rscore == c.second)
1111+
ambiguity = candidate.rscore;
1112+
}
1113+
1114+
if (ambiguity || strict) {
1115+
ASSERT (caller);
1116+
caller->warning( "call to '%s' is ambiguous", m_candidates[0].name());
1117+
for (auto& candidate : m_candidates) {
1118+
if (candidate.rscore >= ambiguity)
1119+
reportAmbiguity(candidate.sym);
1120+
}
1121+
}
1122+
1123+
ASSERT (c.first);
1124+
return {c.first->sym, c.first->rtype};
1125+
}
1126+
1127+
bool empty() const { return m_candidates.empty(); }
1128+
};
9281129

9291130

9301131

@@ -1191,47 +1392,10 @@ ASTfunction_call::typecheck (TypeSpec expected)
11911392
return typecheck_struct_constructor ();
11921393
}
11931394

1194-
bool match = false;
1195-
1196-
// Look for an exact match, including expected return type
1197-
m_typespec = typecheck_all_poly (expected, false, false);
1198-
if (m_typespec != TypeSpec())
1199-
match = true;
1200-
1201-
// Now look for an exact match for arguments, but equivalent return type
1202-
m_typespec = typecheck_all_poly (expected, false, true);
1203-
if (m_typespec != TypeSpec())
1204-
match = true;
1395+
CandidateFunctions candidates(m_compiler, expected, args(), func());
1396+
std::tie(m_sym, m_typespec) = candidates.best(this);
12051397

1206-
// Now look for an exact match on args, but any return type
1207-
if (! match && expected != TypeSpec()) {
1208-
m_typespec = typecheck_all_poly (TypeSpec(), false, false);
1209-
if (m_typespec != TypeSpec())
1210-
match = true;
1211-
}
1212-
1213-
// Now look for a coercible match of args, exact march on return type
1214-
if (! match) {
1215-
m_typespec = typecheck_all_poly (expected, true, false);
1216-
if (m_typespec != TypeSpec())
1217-
match = true;
1218-
}
1219-
1220-
// Now look for a coercible match of args, equivalent march on return type
1221-
if (! match) {
1222-
m_typespec = typecheck_all_poly (expected, true, true);
1223-
if (m_typespec != TypeSpec())
1224-
match = true;
1225-
}
1226-
1227-
// All that failed, try for a coercible match on everything
1228-
if (! match && expected != TypeSpec()) {
1229-
m_typespec = typecheck_all_poly (TypeSpec(), true, false);
1230-
if (m_typespec != TypeSpec())
1231-
match = true;
1232-
}
1233-
1234-
if (match) {
1398+
if (m_sym != nullptr) {
12351399
if (is_user_function()) {
12361400
if (func()->number_of_returns() == 0 &&
12371401
! func()->typespec().is_void()) {
@@ -1245,35 +1409,18 @@ ASTfunction_call::typecheck (TypeSpec expected)
12451409
return m_typespec;
12461410
}
12471411

1412+
// Ambiguity has already been reported.
1413+
if (!candidates.empty())
1414+
return TypeSpec();
1415+
12481416
// Couldn't find any way to match any polymorphic version of the
12491417
// function that we know about. OK, at least try for helpful error
12501418
// message.
1251-
std::string choices ("");
1252-
for (FunctionSymbol *poly = func(); poly; poly = poly->nextpoly()) {
1253-
const char *code = poly->argcodes().c_str();
1254-
int advance;
1255-
TypeSpec returntype = m_compiler->type_from_code (code, &advance);
1256-
code += advance;
1257-
if (choices.length())
1258-
choices += "\n";
1259-
choices += Strutil::format ("\t%s %s (%s)",
1260-
type_c_str(returntype), m_name.c_str(),
1261-
m_compiler->typelist_from_code(code).c_str());
1262-
}
1419+
candidates.reportError(this, m_name);
12631420

1264-
std::string actualargs;
1265-
for (ASTNode::ref arg = args(); arg; arg = arg->next()) {
1266-
if (actualargs.length())
1267-
actualargs += ", ";
1268-
actualargs += arg->typespec().string();
1269-
}
1421+
for (FunctionSymbol *poly = func(); poly; poly = poly->nextpoly())
1422+
candidates.reportAmbiguity(poly);
12701423

1271-
if (choices.size())
1272-
error ("No matching function call to '%s (%s)'\n Candidates are:\n%s",
1273-
m_name.c_str(), actualargs.c_str(), choices.c_str());
1274-
else
1275-
error ("No matching function call to '%s (%s)'",
1276-
m_name.c_str(), actualargs.c_str());
12771424
return TypeSpec();
12781425
}
12791426

testsuite/function-overloads/a_fcn.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
2+
void testA(float a, float b, float c) {
3+
printf("testA float\n");
4+
}
5+
void testA(color a, float b, float c) {
6+
printf("testA color\n");
7+
}
8+
void testA(normal a, float b, float c) {
9+
printf("testA normal\n");
10+
}
11+

testsuite/function-overloads/a_ivp.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
2+
void testA(int a, float b, float c) {
3+
printf("testA int\n");
4+
}
5+
void testA(vector a, float b, float c) {
6+
printf("testA vector\n");
7+
}
8+
void testA(point a, float b, float c) {
9+
printf("testA point\n");
10+
}
11+

0 commit comments

Comments
 (0)