Skip to content

Commit

Permalink
fix: show zero extent references when using pc
Browse files Browse the repository at this point in the history
  • Loading branch information
kasiaMarek committed Jul 11, 2024
1 parent 905fe0e commit e183a61
Show file tree
Hide file tree
Showing 7 changed files with 149 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -791,6 +791,7 @@ class MetalsGlobal(
*/
def namePosition: Position = {
sel match {
case _ if !sel.pos.isRange => sel.pos
case Select(qualifier: Select, name)
if (name == nme.apply || name == nme.unapply) && sel.pos.point == qualifier.pos.point =>
qualifier.namePosition
Expand Down
17 changes: 11 additions & 6 deletions mtags/src/main/scala-2/scala/meta/internal/pc/PcCollector.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import scala.meta.pc.VirtualFileParams
trait PcCollector[T] { self: WithCompilationUnit =>
import compiler._

protected def allowZeroExtent = false

def collect(
parent: Option[Tree]
)(tree: Tree, pos: Position, sym: Option[Symbol]): T
Expand Down Expand Up @@ -87,6 +89,9 @@ trait PcCollector[T] { self: WithCompilationUnit =>
traverseSought(noTreeFilter, noSoughtFilter)
}

def isCorrect(p: Position): Boolean =
p.isRange || (p.isDefined && allowZeroExtent)

def traverseSought(
filter: Tree => Boolean,
soughtFilter: (Symbol => Boolean) => Boolean
Expand All @@ -107,7 +112,7 @@ trait PcCollector[T] { self: WithCompilationUnit =>
* All indentifiers such as:
* val a = <<b>>
*/
case ident: Ident if ident.pos.isRange && filter(ident) =>
case ident: Ident if isCorrect(ident.pos) && filter(ident) =>
if (ident.symbol == NoSymbol)
acc + collect(ident, ident.pos, fallbackSymbol(ident.name, pos))
else
Expand All @@ -118,7 +123,7 @@ trait PcCollector[T] { self: WithCompilationUnit =>
* type A = [<<b>>]
*/
case tpe: TypeTree
if tpe.pos.isRange && tpe.original != null && filter(tpe) =>
if isCorrect(tpe.pos) && tpe.original != null && filter(tpe) =>
tpe.original.children.foldLeft(
acc + collect(
tpe.original,
Expand All @@ -130,7 +135,7 @@ trait PcCollector[T] { self: WithCompilationUnit =>
* All select statements such as:
* val a = hello.<<b>>
*/
case sel: Select if sel.pos.isRange && filter(sel) =>
case sel: Select if isCorrect(sel.pos) && filter(sel) =>
val newAcc =
if (isForComprehensionMethod(sel)) acc
else acc + collect(sel, sel.namePosition)
Expand All @@ -146,7 +151,7 @@ trait PcCollector[T] { self: WithCompilationUnit =>
*/
case Function(params, body) =>
val newAcc = params
.filter(vd => vd.pos.isRange && filter(vd))
.filter(vd => isCorrect(vd.pos) && filter(vd))
.foldLeft(
acc
) { case (acc, vd) =>
Expand All @@ -167,7 +172,7 @@ trait PcCollector[T] { self: WithCompilationUnit =>
* class <<Foo>> = ???
* etc.
*/
case df: MemberDef if df.pos.isRange && filter(df) =>
case df: MemberDef if isCorrect(df.pos) && filter(df) =>
(annotationChildren(df) ++ df.children).foldLeft({
val t = collect(
df,
Expand Down Expand Up @@ -277,7 +282,7 @@ trait PcCollector[T] { self: WithCompilationUnit =>
res
// catch all missed named trees
case name: NameTree
if soughtFilter(_ == name.symbol) && name.pos.isRange =>
if soughtFilter(_ == name.symbol) && isCorrect(name.pos) =>
tree.children.foldLeft(
acc + collect(
name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ trait PcReferencesProvider {
import compiler._
protected def includeDefinition: Boolean
protected def result(): List[(String, Option[l.Range])]
override def allowZeroExtent = true

def collect(
parent: Option[Tree]
Expand All @@ -27,7 +28,15 @@ trait PcReferencesProvider {
case t: DefTree if !includeDefinition =>
(compiler.semanticdbSymbol(sym.getOrElse(t.symbol)), None)
case t =>
(compiler.semanticdbSymbol(sym.getOrElse(t.symbol)), Some(pos.toLsp))
parent match {
case Some(p: DefTree) if p.symbol.isAccessor =>
(compiler.semanticdbSymbol(sym.getOrElse(t.symbol)), None)
case _ =>
(
compiler.semanticdbSymbol(sym.getOrElse(t.symbol)),
Some(pos.toLsp)
)
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ import dotty.tools.dotc.util.Spans.Span
trait PcCollector[T]:
self: WithCompilationUnit =>

def allowZeroExtent: Boolean = false

def collect(
parent: Option[Tree]
)(tree: Tree | EndMarker, pos: SourcePosition, symbol: Option[Symbol]): T
Expand Down Expand Up @@ -79,7 +81,7 @@ trait PcCollector[T]:

extension (span: Span)
def isCorrect =
!span.isZeroExtent && span.exists && span.start < sourceText.size && span.end <= sourceText.size
(allowZeroExtent || !span.isZeroExtent) && span.exists && span.start < sourceText.size && span.end <= sourceText.size

def traverseSought(
filter: Tree => Boolean,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class PcReferencesProvider(
driver: InteractiveDriver,
request: ReferencesRequest,
) extends WithCompilationUnit(driver, request.file()) with PcCollector[Option[(String, Option[lsp4j.Range])]]:
override def allowZeroExtent: Boolean = true

private def soughtSymbols =
if(request.offsetOrSymbol().isLeft()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ class BaseDocumentHighlightSuite extends BasePCSuite with RangeReplace {
def check(
name: TestOptions,
original: String
// compat: Map[String, String] = Map.empty
)(implicit location: Location): Unit =
test(name) {

Expand Down
123 changes: 123 additions & 0 deletions tests/cross/src/test/scala/tests/pc/PcReferencesSuite.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package tests.pc

import java.net.URI

import scala.meta.internal.jdk.CollectionConverters._
import scala.meta.internal.metals.CompilerVirtualFileParams
import scala.meta.internal.metals.EmptyCancelToken
import scala.meta.internal.pc.PcReferencesRequest

import munit.Location
import munit.TestOptions
import org.eclipse.lsp4j.jsonrpc.messages.{Either => JEither}
import tests.BasePCSuite
import tests.RangeReplace

class PcReferencesSuite extends BasePCSuite with RangeReplace {
def check(
name: TestOptions,
original: String,
compat: Map[String, String] = Map.empty
)(implicit location: Location): Unit =
test(name) {
val edit = original.replaceAll("(<<|>>)", "")
val expected = original.replaceAll("@@", "")
val base = original.replaceAll("(<<|>>|@@)", "")

val (code, offset) = params(edit, "Highlight.scala")
val ranges = presentationCompiler
.references(
PcReferencesRequest(
CompilerVirtualFileParams(
URI.create("file:/Highlight.scala"),
code,
EmptyCancelToken
),
includeDefinition = false,
offsetOrSymbol = JEither.forLeft(offset)
)
)
.get()
.asScala
.flatMap(_.locations().asScala.map(_.getRange()))
.toList

assertEquals(
renderRangesAsString(base, ranges),
getExpected(expected, compat, scalaVersion)
)
}

check(
"implicit-args",
"""|package example
|
|class Bar(i: Int)
|
|object Hello {
| def m(i: Int)(implicit b: Bar) = ???
| val foo = {
| implicit val b@@arr: Bar = new Bar(1)
| m<<>>(3)
| }
|}
|""".stripMargin,
compat = Map("3" -> """|package example
|
|class Bar(i: Int)
|
|object Hello {
| def m(i: Int)(implicit b: Bar) = ???
| val foo = {
| implicit val barr: Bar = new Bar(1)
| m(3)<<>>
| }
|}
|""".stripMargin)
)

check(
"implicit-args-2",
"""|package example
|
|class Bar(i: Int)
|class Foo(implicit b: Bar)
|
|object Hello {
| implicit val b@@arr: Bar = new Bar(1)
| val foo = <<>>new Foo
|}
|""".stripMargin,
compat = Map(
"3" -> """|package example
|
|class Bar(i: Int)
|class Foo(implicit b: Bar)
|
|object Hello {
| implicit val barr: Bar = new Bar(1)
| val foo = new Foo<<>>
|}
|""".stripMargin
)
)

// for Scala 3 the symbol in bar reference is missing (<none>)
check(
"implicit-args-3".tag(IgnoreScala3),
"""|package example
|
|class Bar(i: Int)
|class Foo(implicit b: Bar)
|
|object Hello {
| implicit val b@@arr = new Bar(1)
| for {
| _ <- Some(1)
| foo = <<>>new Foo()
| } yield ()
|}
|""".stripMargin
)

}

0 comments on commit e183a61

Please sign in to comment.