diff --git a/mtags/src/main/scala-2/scala/meta/internal/pc/MetalsGlobal.scala b/mtags/src/main/scala-2/scala/meta/internal/pc/MetalsGlobal.scala index 0eaf8b3b5ad..8c634fc312f 100644 --- a/mtags/src/main/scala-2/scala/meta/internal/pc/MetalsGlobal.scala +++ b/mtags/src/main/scala-2/scala/meta/internal/pc/MetalsGlobal.scala @@ -791,6 +791,7 @@ class MetalsGlobal( */ def namePosition: Position = { sel match { + case _ if !sel.pos.isRange => sel.pos case Select(qualifier: Select, name) if (name == nme.apply || name == nme.unapply) && sel.pos.point == qualifier.pos.point => qualifier.namePosition diff --git a/mtags/src/main/scala-2/scala/meta/internal/pc/PcCollector.scala b/mtags/src/main/scala-2/scala/meta/internal/pc/PcCollector.scala index 9b9165cf77d..ca19ba1eb85 100644 --- a/mtags/src/main/scala-2/scala/meta/internal/pc/PcCollector.scala +++ b/mtags/src/main/scala-2/scala/meta/internal/pc/PcCollector.scala @@ -8,6 +8,8 @@ import scala.meta.pc.VirtualFileParams trait PcCollector[T] { self: WithCompilationUnit => import compiler._ + protected def allowZeroExtentImplicits = false + def collect( parent: Option[Tree] )(tree: Tree, pos: Position, sym: Option[Symbol]): T @@ -87,6 +89,9 @@ trait PcCollector[T] { self: WithCompilationUnit => traverseSought(noTreeFilter, noSoughtFilter) } + def isCorrectPos(t: Tree): Boolean = + t.pos.isRange || (t.pos.isDefined && allowZeroExtentImplicits && t.symbol.isImplicit) + def traverseSought( filter: Tree => Boolean, soughtFilter: (Symbol => Boolean) => Boolean @@ -107,7 +112,7 @@ trait PcCollector[T] { self: WithCompilationUnit => * All indentifiers such as: * val a = <> */ - case ident: Ident if ident.pos.isRange && filter(ident) => + case ident: Ident if isCorrectPos(ident) && filter(ident) => if (ident.symbol == NoSymbol) acc + collect(ident, ident.pos, fallbackSymbol(ident.name, pos)) else @@ -118,7 +123,7 @@ trait PcCollector[T] { self: WithCompilationUnit => * type A = [<>] */ case tpe: TypeTree - if tpe.pos.isRange && tpe.original != null && filter(tpe) => + if isCorrectPos(tpe) && tpe.original != null && filter(tpe) => tpe.original.children.foldLeft( acc + collect( tpe.original, @@ -130,7 +135,7 @@ trait PcCollector[T] { self: WithCompilationUnit => * All select statements such as: * val a = hello.<> */ - case sel: Select if sel.pos.isRange && filter(sel) => + case sel: Select if isCorrectPos(sel) && filter(sel) => val newAcc = if (isForComprehensionMethod(sel)) acc else acc + collect(sel, sel.namePosition) @@ -146,7 +151,7 @@ trait PcCollector[T] { self: WithCompilationUnit => */ case Function(params, body) => val newAcc = params - .filter(vd => vd.pos.isRange && filter(vd)) + .filter(vd => isCorrectPos(vd) && filter(vd)) .foldLeft( acc ) { case (acc, vd) => @@ -167,7 +172,7 @@ trait PcCollector[T] { self: WithCompilationUnit => * class <> = ??? * etc. */ - case df: MemberDef if df.pos.isRange && filter(df) => + case df: MemberDef if isCorrectPos(df) && filter(df) => (annotationChildren(df) ++ df.children).foldLeft({ val t = collect( df, @@ -277,7 +282,7 @@ trait PcCollector[T] { self: WithCompilationUnit => res // catch all missed named trees case name: NameTree - if soughtFilter(_ == name.symbol) && name.pos.isRange => + if soughtFilter(_ == name.symbol) && isCorrectPos(name) => tree.children.foldLeft( acc + collect( name, diff --git a/mtags/src/main/scala-2/scala/meta/internal/pc/PcReferencesProvider.scala b/mtags/src/main/scala-2/scala/meta/internal/pc/PcReferencesProvider.scala index 0cd4f35d132..06d0df536c1 100644 --- a/mtags/src/main/scala-2/scala/meta/internal/pc/PcReferencesProvider.scala +++ b/mtags/src/main/scala-2/scala/meta/internal/pc/PcReferencesProvider.scala @@ -14,6 +14,7 @@ trait PcReferencesProvider { import compiler._ protected def includeDefinition: Boolean protected def result(): List[(String, Option[l.Range])] + override def allowZeroExtentImplicits = true def collect( parent: Option[Tree] @@ -27,7 +28,10 @@ trait PcReferencesProvider { case t: DefTree if !includeDefinition => (compiler.semanticdbSymbol(sym.getOrElse(t.symbol)), None) case t => - (compiler.semanticdbSymbol(sym.getOrElse(t.symbol)), Some(pos.toLsp)) + ( + compiler.semanticdbSymbol(sym.getOrElse(t.symbol)), + Some(pos.toLsp) + ) } } diff --git a/mtags/src/main/scala-3/scala/meta/internal/pc/PcCollector.scala b/mtags/src/main/scala-3/scala/meta/internal/pc/PcCollector.scala index c9b077cd32a..a86e7bff308 100644 --- a/mtags/src/main/scala-3/scala/meta/internal/pc/PcCollector.scala +++ b/mtags/src/main/scala-3/scala/meta/internal/pc/PcCollector.scala @@ -25,6 +25,8 @@ import dotty.tools.dotc.util.Spans.Span trait PcCollector[T]: self: WithCompilationUnit => + def allowZeroExtentImplicits: Boolean = false + def collect( parent: Option[Tree] )(tree: Tree | EndMarker, pos: SourcePosition, symbol: Option[Symbol]): T @@ -81,6 +83,10 @@ trait PcCollector[T]: def isCorrect = !span.isZeroExtent && span.exists && span.start < sourceText.size && span.end <= sourceText.size + extension (tree: Tree) + def isCorrectSpan = + tree.span.isCorrect || (allowZeroExtentImplicits && tree.symbol.is(Flags.Implicit)) + def traverseSought( filter: Tree => Boolean, soughtFilter: (Symbol => Boolean) => Boolean, @@ -102,7 +108,7 @@ trait PcCollector[T]: * val a = <> */ case ident: Ident - if ident.span.isCorrect && filter(ident) && !isExtensionMethodCall( + if ident.isCorrectSpan && filter(ident) && !isExtensionMethodCall( parent, ident.symbol, ) => @@ -121,7 +127,7 @@ trait PcCollector[T]: * val x = new <>(1) */ case sel @ Select(New(t), _) - if sel.span.isCorrect && + if sel.isCorrectSpan && sel.symbol.isConstructor && t.symbol == NoSymbol => if soughtFilter(_ == sel.symbol.owner) then @@ -136,7 +142,7 @@ trait PcCollector[T]: * val a = hello.<> */ case sel: Select - if sel.span.isCorrect && filter(sel) && + if sel.isCorrectSpan && filter(sel) && !sel.isForComprehensionMethod => occurences + collect( sel, diff --git a/mtags/src/main/scala-3/scala/meta/internal/pc/PcReferencesProvider.scala b/mtags/src/main/scala-3/scala/meta/internal/pc/PcReferencesProvider.scala index 220523282fc..1fc14ada5e3 100644 --- a/mtags/src/main/scala-3/scala/meta/internal/pc/PcReferencesProvider.scala +++ b/mtags/src/main/scala-3/scala/meta/internal/pc/PcReferencesProvider.scala @@ -19,6 +19,7 @@ class PcReferencesProvider( driver: InteractiveDriver, request: ReferencesRequest, ) extends WithCompilationUnit(driver, request.file()) with PcCollector[Option[(String, Option[lsp4j.Range])]]: + override def allowZeroExtentImplicits: Boolean = true private def soughtSymbols = if(request.offsetOrSymbol().isLeft()) { diff --git a/tests/cross/src/main/scala/tests/BaseDocumentHighlightSuite.scala b/tests/cross/src/main/scala/tests/BaseDocumentHighlightSuite.scala index bbc82345aa2..b6700110d4c 100644 --- a/tests/cross/src/main/scala/tests/BaseDocumentHighlightSuite.scala +++ b/tests/cross/src/main/scala/tests/BaseDocumentHighlightSuite.scala @@ -15,7 +15,6 @@ class BaseDocumentHighlightSuite extends BasePCSuite with RangeReplace { def check( name: TestOptions, original: String - // compat: Map[String, String] = Map.empty )(implicit location: Location): Unit = test(name) { diff --git a/tests/cross/src/test/scala/tests/pc/PcReferencesSuite.scala b/tests/cross/src/test/scala/tests/pc/PcReferencesSuite.scala new file mode 100644 index 00000000000..c51dd94c06d --- /dev/null +++ b/tests/cross/src/test/scala/tests/pc/PcReferencesSuite.scala @@ -0,0 +1,135 @@ +package tests.pc + +import java.net.URI + +import scala.meta.internal.jdk.CollectionConverters._ +import scala.meta.internal.metals.CompilerVirtualFileParams +import scala.meta.internal.metals.EmptyCancelToken +import scala.meta.internal.pc.PcReferencesRequest + +import munit.Location +import munit.TestOptions +import org.eclipse.lsp4j.jsonrpc.messages.{Either => JEither} +import tests.BasePCSuite +import tests.RangeReplace + +class PcReferencesSuite extends BasePCSuite with RangeReplace { + def check( + name: TestOptions, + original: String, + compat: Map[String, String] = Map.empty + )(implicit location: Location): Unit = + test(name) { + val edit = original.replaceAll("(<<|>>)", "") + val expected = original.replaceAll("@@", "") + val base = original.replaceAll("(<<|>>|@@)", "") + + val (code, offset) = params(edit, "Highlight.scala") + val ranges = presentationCompiler + .references( + PcReferencesRequest( + CompilerVirtualFileParams( + URI.create("file:/Highlight.scala"), + code, + EmptyCancelToken + ), + includeDefinition = false, + offsetOrSymbol = JEither.forLeft(offset) + ) + ) + .get() + .asScala + .flatMap(_.locations().asScala.map(_.getRange())) + .toList + + assertEquals( + renderRangesAsString(base, ranges), + getExpected(expected, compat, scalaVersion) + ) + } + + check( + "implicit-args", + """|package example + | + |class Bar(i: Int) + | + |object Hello { + | def m(i: Int)(implicit b: Bar) = ??? + | val foo = { + | implicit val b@@arr: Bar = new Bar(1) + | m<<>>(3) + | } + |} + |""".stripMargin, + compat = Map("3" -> """|package example + | + |class Bar(i: Int) + | + |object Hello { + | def m(i: Int)(implicit b: Bar) = ??? + | val foo = { + | implicit val barr: Bar = new Bar(1) + | m(3)<<>> + | } + |} + |""".stripMargin) + ) + + check( + "implicit-args-2", + """|package example + | + |class Bar(i: Int) + |class Foo(implicit b: Bar) + | + |object Hello { + | implicit val b@@arr: Bar = new Bar(1) + | val foo = <<>>new Foo + |} + |""".stripMargin, + compat = Map( + "3" -> """|package example + | + |class Bar(i: Int) + |class Foo(implicit b: Bar) + | + |object Hello { + | implicit val barr: Bar = new Bar(1) + | val foo = new Foo<<>> + |} + |""".stripMargin + ) + ) + + // for Scala 3 the symbol in bar reference is missing () + check( + "implicit-args-3".tag(IgnoreScala3), + """|package example + | + |class Bar(i: Int) + |class Foo(implicit b: Bar) + | + |object Hello { + | implicit val b@@arr = new Bar(1) + | for { + | _ <- Some(1) + | foo = <<>>new Foo() + | } yield () + |} + |""".stripMargin + ) + + check( + "case-class", + """|case class Ma@@in(i: Int) + |""".stripMargin + ) + + check( + "case-class-with-implicit", + """"|case class A()(implicit val fo@@o: Int) + |""".stripMargin + ) + +}