Skip to content

Commit

Permalink
Improved return inspection handling for unions of multiple results
Browse files Browse the repository at this point in the history
  • Loading branch information
Benjamin-Dobell committed Dec 27, 2020
1 parent 67c9c5a commit 0590153
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,83 +43,117 @@ class ReturnTypeInspection : StrictInspection() {

val context = PsiSearchContext(o)
val bodyOwner = PsiTreeUtil.getParentOfType(o, LuaFuncBodyOwner::class.java) ?: return
val expectedReturnType = if (bodyOwner is LuaClassMethodDefStat) {
val functionReturnDocTy = if (bodyOwner is LuaClassMethodDefStat) {
guessSuperReturnTypes(bodyOwner, context)
} else {
bodyOwner.tagReturn?.type
} ?: TyMultipleResults(listOf(Ty.UNKNOWN), true)

val concreteType = context.withMultipleResults {
val concreteTy = context.withMultipleResults {
o.exprList?.guessType(context)?.let {
TyMultipleResults.flatten(context, it)
}
} ?: Ty.VOID

val concreteTypes = toList(concreteType)
val concreteTys = toList(concreteTy)

val documentedReturnTypeTag = o.comment?.let { PsiTreeUtil.getChildrenOfTypeAsList(it, LuaDocTagTypeImpl::class.java).firstOrNull() }
val documentedType = documentedReturnTypeTag?.getType()
val statementDocTagType = o.comment?.let { PsiTreeUtil.getChildrenOfTypeAsList(it, LuaDocTagTypeImpl::class.java).firstOrNull() }
val statementDocTy = statementDocTagType?.getType()

val abstractTypes = toList(documentedType ?: expectedReturnType)
val variadicAbstractType = if (expectedReturnType is TyMultipleResults && expectedReturnType.variadic) {
expectedReturnType.list.last()
} else null
val processCandidate = fun(candidateReturnTy: ITy): Collection<Problem> {
val problems = mutableListOf<Problem>()

for (i in 0 until concreteTypes.size) {
val element = o.exprList?.getExpressionAt(i) ?: o
val targetType = abstractTypes.getOrNull(i) ?: variadicAbstractType ?: Ty.VOID
val varianceFlags = if (element is LuaTableExpr) TyVarianceFlags.WIDEN_TABLES else 0
val abstractTys = toList(statementDocTy ?: candidateReturnTy)
val variadicAbstractType = if (candidateReturnTy is TyMultipleResults && candidateReturnTy.variadic) {
candidateReturnTy.list.last()
} else null

for (i in 0 until concreteTys.size) {
val element = o.exprList?.getExpressionAt(i) ?: o
val targetType = abstractTys.getOrNull(i) ?: variadicAbstractType ?: Ty.VOID
val varianceFlags = if (element is LuaTableExpr) TyVarianceFlags.WIDEN_TABLES else 0

ProblemUtil.contravariantOf(targetType, concreteTys[i], context, varianceFlags, null, element) { problem ->
val targetMessage = problem.message

ProblemUtil.contravariantOf(targetType, concreteTypes[i], context, varianceFlags, null, element) { problem ->
val sourceElement = problem.sourceElement
val targetElement = problem.targetElement
val sourceMessage = if (concreteTypes.size > 1) "Result ${i + 1}, ${problem.message.decapitalize()}" else problem.message
val highlightType = problem.highlightType ?: ProblemHighlightType.GENERIC_ERROR_OR_WARNING
if (concreteTys.size > 1) {
problem.message = "Result ${i + 1}, ${targetMessage.decapitalize()}"
}

myHolder.registerProblem(sourceElement, sourceMessage, highlightType)
problems.add(problem)

if (targetElement != null && targetElement != sourceElement) {
myHolder.registerProblem(targetElement, problem.message, highlightType)
if (problem.targetElement != null && problem.targetElement != problem.sourceElement) {
problems.add(Problem(null, problem.targetElement, targetMessage, problem.highlightType))
}
}
}
}

val abstractReturnCount = if (variadicAbstractType != null) {
abstractTypes.size - 1
} else abstractTypes.size
val abstractReturnCount = if (variadicAbstractType != null) {
abstractTys.size - 1
} else abstractTys.size

val concreteReturnCount = if (concreteType is TyMultipleResults && concreteType.variadic) {
concreteTypes.size - 1
} else concreteTypes.size
val concreteReturnCount = if (concreteTy is TyMultipleResults && concreteTy.variadic) {
concreteTys.size - 1
} else concreteTys.size

if (concreteReturnCount < abstractReturnCount) {
myHolder.registerProblem(o.lastChild, "Incorrect number of values. Expected %s but found %s.".format(abstractReturnCount, concreteReturnCount))
}
if (concreteReturnCount < abstractReturnCount) {
problems.add(Problem(null, o.lastChild, "Incorrect number of values. Expected %s but found %s.".format(abstractReturnCount, concreteReturnCount)))
}

if (documentedType != null) {
val expectedReturnTypes = toList(expectedReturnType)
val variadicExpectedReturnType = if (expectedReturnType is TyMultipleResults && expectedReturnType.variadic) {
expectedReturnType.list.last()
} else null
if (statementDocTy != null) {
val expectedReturnTys = toList(candidateReturnTy)
val expectedVariadicReturnTy = if (candidateReturnTy is TyMultipleResults && candidateReturnTy.variadic) {
candidateReturnTy.list.last()
} else null

for (i in 0 until abstractTypes.size) {
val targetType = expectedReturnTypes.getOrNull(i) ?: variadicExpectedReturnType ?: Ty.VOID
for (i in 0 until abstractTys.size) {
val targetType = expectedReturnTys.getOrNull(i) ?: expectedVariadicReturnTy ?: Ty.VOID

if (!targetType.contravariantOf(abstractTypes[i], context, 0)) {
val element = documentedReturnTypeTag.typeList?.tyList?.let { it.getOrNull(i) ?: it.last() } ?: documentedReturnTypeTag
val message = "Type mismatch. Required: '%s' Found: '%s'".format(targetType.displayName, abstractTypes[i].displayName)
myHolder.registerProblem(element, message)
if (!targetType.contravariantOf(abstractTys[i], context, 0)) {
val element = statementDocTagType.typeList?.tyList?.let { it.getOrNull(i) ?: it.last() } ?: statementDocTagType
val message = "Type mismatch. Required: '%s' Found: '%s'".format(targetType.displayName, abstractTys[i].displayName)
problems.add(Problem(null, element, message))
}
}

val candidateReturnCount = if (expectedVariadicReturnTy != null) {
expectedReturnTys.size - 1
} else expectedReturnTys.size

if (abstractReturnCount < candidateReturnCount) {
val element = statementDocTagType.typeList ?: statementDocTagType
val message = "Incorrect number of values. Expected %s but found %s.".format(candidateReturnCount, abstractReturnCount)
problems.add(Problem(null, element, message))
}
}

val expectedReturnCount = if (variadicExpectedReturnType != null) {
expectedReturnTypes.size - 1
} else expectedReturnTypes.size
return problems
}

if (abstractReturnCount < expectedReturnCount) {
val element = documentedReturnTypeTag.typeList ?: documentedReturnTypeTag
val message = "Incorrect number of values. Expected %s but found %s.".format(expectedReturnCount, abstractReturnCount)
myHolder.registerProblem(element, message)
val multipleCandidates = functionReturnDocTy is TyUnion && functionReturnDocTy.getChildTypes().any { it is TyMultipleResults }

if (multipleCandidates) {
val candidateProblems = mutableMapOf<String, Collection<Problem>>()

TyUnion.each(functionReturnDocTy) {
val problems = processCandidate(it)

if (problems.size == 0) {
return
}

candidateProblems.put(it.displayName, problems)
}

candidateProblems.forEach { candidate, problems ->
problems.forEach {
val message = "${it.message} for candidate return type (${candidate})"
myHolder.registerProblem(it.sourceElement, message, it.highlightType ?: ProblemHighlightType.GENERIC_ERROR_OR_WARNING)
}
}
} else {
processCandidate(functionReturnDocTy).forEach {
myHolder.registerProblem(it.sourceElement, it.message, it.highlightType ?: ProblemHighlightType.GENERIC_ERROR_OR_WARNING)
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/com/tang/intellij/lua/ty/Ty.kt
Original file line number Diff line number Diff line change
Expand Up @@ -766,7 +766,7 @@ class TyUnknown : Ty(TyKind.Unknown) {
}

override fun contravariantOf(other: ITy, context: SearchContext, flags: Int): Boolean {
return true
return other !is TyMultipleResults
}

override fun guessMemberType(name: String, searchContext: SearchContext): ITy? {
Expand Down
12 changes: 12 additions & 0 deletions src/test/resources/inspections/function_multiple_returns.lua
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,15 @@ local inferredNumberOrString = <weak_warning descr="Insufficient assignees, valu

numberOrString = inferredNumberOrString
aNumber = <error descr="Type mismatch. Required: 'number' Found: '1 | string'">inferredNumberOrString</error> -- Expect error

---@param val number
---@return (string, number) | (number, function, string)
local function returnUnionOfMultipleResults(val)
if val > 0 then
return aString, aNumber
elseif val < 0 then
return aNumber, function() end, aString
end

return <error descr="Incorrect number of values. Expected 3 but found 2. for candidate return type (number, function, string)"><error descr="Result 1, type mismatch. Required: 'string' Found: 'number' for candidate return type (string, number)">aNumber</error>, <error descr="Result 2, type mismatch. Required: 'function' Found: 'number' for candidate return type (number, function, string)">aNumber</error></error>
end

0 comments on commit 0590153

Please sign in to comment.