diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 815c88ceb465..c57ced36f2d8 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -1929,15 +1929,44 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer NoType } - pt.stripNull() match { - case pt: TypeVar - if untpd.isFunctionWithUnknownParamType(tree) && !calleeType.exists => - // try to instantiate `pt` if this is possible. If it does not - // work the error will be reported later in `inferredParam`, - // when we try to infer the parameter type. - isFullyDefined(pt, ForceDegree.flipBottom) - case _ => - } + /** Try to instantiate one type variable bounded by function types that appear + * deeply inside `tp`, including union or intersection types. + */ + def tryToInstantiateDeeply(tp: Type): Boolean = tp.dealias match + case tp: AndOrType => + tryToInstantiateDeeply(tp.tp1) + || tryToInstantiateDeeply(tp.tp2) + case tp: FlexibleType => + tryToInstantiateDeeply(tp.hi) + case tp: TypeVar if isConstrainedByFunctionType(tp) => + // Only instantiate if the type variable is constrained by function types + isFullyDefined(tp, ForceDegree.flipBottom) + case _ => false + + def isConstrainedByFunctionType(tvar: TypeVar): Boolean = + val origin = tvar.origin + val bounds = ctx.typerState.constraint.bounds(origin) + // The search is done by the best-effort, and we don't look into TypeVars recursively. + def containsFunctionType(tp: Type): Boolean = tp.dealias match + case tp if defn.isFunctionType(tp) => true + case SAMType(_, _) => true + case tp: AndOrType => + containsFunctionType(tp.tp1) || containsFunctionType(tp.tp2) + case tp: FlexibleType => + containsFunctionType(tp.hi) + case _ => false + containsFunctionType(bounds.lo) || containsFunctionType(bounds.hi) + + if untpd.isFunctionWithUnknownParamType(tree) && !calleeType.exists then + // Try to instantiate `pt` when possible. + // * If `pt` is a type variable, we try to instantiate it directly. + // * If `pt` is a more complex type, we try to instantiate it deeply by searching + // a nested type variable bounded by a function type to help infer parameter types. + // If it does not work the error will be reported later in `inferredParam`, + // when we try to infer the parameter type. + pt match + case pt: TypeVar => isFullyDefined(pt, ForceDegree.flipBottom) + case _ => tryToInstantiateDeeply(pt) val (protoFormals, resultTpt) = decomposeProtoFunction(pt, params.length, tree.srcPos) diff --git a/tests/pos/infer-function-type-in-union.scala b/tests/pos/infer-function-type-in-union.scala new file mode 100644 index 000000000000..f631761b3897 --- /dev/null +++ b/tests/pos/infer-function-type-in-union.scala @@ -0,0 +1,42 @@ + +def f[T](x: T): T = ??? +def f2[T](x: T | T): T = ??? +def f3[T](x: T | Null): T = ??? +def f4[T](x: Int | T): T = ??? + +trait MyOption[+T] + +object MyOption: + def apply[T](x: T | Null): MyOption[T] = ??? + +def test = + val g: AnyRef => Boolean = f { + x => x eq null // ok + } + val g2: AnyRef => Boolean = f2 { + x => x eq null // ok + } + val g3: AnyRef => Boolean = f3 { + x => x eq null // was error + } + val g4: AnyRef => Boolean = f4 { + x => x eq null // was error + } + + val o1: MyOption[String] = MyOption(null) + val o2: MyOption[String => Boolean] = MyOption { + x => x.length > 0 + } + val o3: MyOption[(String, String) => Boolean] = MyOption { + (x, y) => x.length > y.length + } + + +class Box[T] +val box: Box[Unit] = ??? +def ff1[T, U](x: T | U, y: Box[U]): T = ??? +def ff2[T, U](x: T & U): T = ??? + +def test2 = + val a1: Any => Any = ff1(x => x, box) + val a2: Any => Any = ff2(x => x) \ No newline at end of file