Skip to content

Commit

Permalink
bugfix: emit as extension method if member of implicit class with val…
Browse files Browse the repository at this point in the history
… in constructor (#6190)

* bugfix: emit as extension method if member of implicit class with val in constructor

* account for refined types

* delete commented out code
  • Loading branch information
kasiaMarek authored Mar 11, 2024
1 parent 151ba43 commit b2b0ac5
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,13 @@ class ScalaToplevelMtags(
case _ => region
}

private def isInParenthesis(region: Region): Boolean =
region match {
case (_: Region.InParenCaseClass) | (_: Region.InParenClass) =>
true
case _ => false
}

@tailrec
private def loop(
indent: Int,
Expand Down Expand Up @@ -290,12 +297,23 @@ class ScalaToplevelMtags(
)
case DEF | VAL | VAR | GIVEN
if expectTemplate.map(!_.isExtension).getOrElse(true) =>
val isImplicit =
if (isInParenthesis(region))
expectTemplate.exists(
_.isImplicit
)
else region.isImplicit
if (needEmitTermMember()) {
withOwner(currRegion.termOwner) {
emitTerm(currRegion)
emitTerm(currRegion, isImplicit)
}
} else scanner.nextToken()
loop(indent, isAfterNewline = false, currRegion, newExpectIgnoreBody)
loop(
indent,
isAfterNewline = false,
currRegion,
if (isInParenthesis(region)) expectTemplate else newExpectIgnoreBody
)
case TYPE if expectTemplate.map(!_.isExtension).getOrElse(true) =>
if (needEmitMember(currRegion) && !prevWasDot) {
withOwner(currRegion.termOwner) {
Expand Down Expand Up @@ -408,15 +426,24 @@ class ScalaToplevelMtags(
expectTemplate match {
case Some(expect)
if needToParseBody(expect) || needToParseExtension(expect) =>
val next =
expect.startInBraceRegion(
currRegion,
expect.isExtension,
expect.isImplicit
)
resetRegion(next)
scanner.nextToken()
loop(indent, isAfterNewline = false, next, None)
if (isInParenthesis(region)) {
// inside of a class constructor
// e.g. class A(val foo: Foo { type T = Int })
// ^
acceptBalancedDelimeters(LBRACE, RBRACE)
scanner.nextToken()
loop(indent, isAfterNewline = false, currRegion, expectTemplate)
} else {
val next =
expect.startInBraceRegion(
currRegion,
expect.isExtension,
expect.isImplicit
)
resetRegion(next)
scanner.nextToken()
loop(indent, isAfterNewline = false, next, None)
}
case _ =>
acceptBalancedDelimeters(LBRACE, RBRACE)
scanner.nextToken()
Expand Down Expand Up @@ -716,7 +743,8 @@ class ScalaToplevelMtags(
/**
* Enters a global element (def/val/var/given)
*/
def emitTerm(region: Region): Unit = {
def emitTerm(region: Region, isParentImplicit: Boolean): Unit = {
val extensionProperty = if (isParentImplicit) EXTENSION else 0
val kind = scanner.curr.token
acceptTrivia()
kind match {
Expand All @@ -726,7 +754,7 @@ class ScalaToplevelMtags(
name.name,
name.pos,
Kind.METHOD,
SymbolInformation.Property.VAL.value
SymbolInformation.Property.VAL.value | extensionProperty
)
resetRegion(region)
})
Expand All @@ -736,7 +764,7 @@ class ScalaToplevelMtags(
name.name,
"()",
name.pos,
SymbolInformation.Property.VAR.value
SymbolInformation.Property.VAR.value | extensionProperty
)
resetRegion(region)
})
Expand All @@ -746,7 +774,7 @@ class ScalaToplevelMtags(
name.name,
region.overloads.disambiguator(name.name),
name.pos,
0
extensionProperty
)
)
case GIVEN =>
Expand Down
12 changes: 12 additions & 0 deletions tests/unit/src/test/scala/tests/Example.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package tests

package a
object O {
trait Foo {
type T
}

implicit class A(val foo: Foo { type T = Int }) {
def get: Int = 1
}
}
57 changes: 47 additions & 10 deletions tests/unit/src/test/scala/tests/ScalaToplevelSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -349,8 +349,8 @@ class ScalaToplevelSuite extends BaseSuite {
List(
"a/",
"a/A.",
"a/A.bar().",
"a/A.foo().",
"a/A.bar(). EXT",
"a/A.foo(). EXT",
),
mode = All,
dialect = dialects.Scala3,
Expand All @@ -370,8 +370,8 @@ class ScalaToplevelSuite extends BaseSuite {
List(
"a/",
"a/Test$package.",
"a/Test$package.bar().",
"a/Test$package.foo().",
"a/Test$package.bar(). EXT",
"a/Test$package.foo(). EXT",
),
mode = All,
dialect = dialects.Scala3,
Expand All @@ -387,8 +387,8 @@ class ScalaToplevelSuite extends BaseSuite {
| def baz: Long = ???
|""".stripMargin,
List(
"a/", "a/Test$package.", "a/Test$package.foo().", "a/Test$package.bar().",
"a/Test$package.baz().",
"a/", "a/Test$package.", "a/Test$package.foo(). EXT",
"a/Test$package.bar(). EXT", "a/Test$package.baz(). EXT",
),
mode = All,
dialect = dialects.Scala3,
Expand Down Expand Up @@ -655,6 +655,42 @@ class ScalaToplevelSuite extends BaseSuite {
mode = All,
)

check(
"refined-type",
"""|package a
|object O {
| trait Foo {
| type T
| }
|
| implicit class A(val foo: Foo { type T = Int }) {
| def get: Int = 1
| }
|}
|""".stripMargin,
List(
"a/", "a/O.", "a/O.A#", "a/O.A#foo. EXT", "a/O.A#get(). EXT", "a/O.Foo#",
"a/O.Foo#T#",
),
mode = All,
)

check(
"implicit-class-with-val",
"""|package a
|object Foo {
| implicit class IntOps(private val i: Int) extends AnyVal {
| def inc: Int = i + 1
| }
|}
|""".stripMargin,
List(
"a/", "a/Foo.", "a/Foo.IntOps# -> AnyVal", "a/Foo.IntOps#i. EXT",
"a/Foo.IntOps#inc(). EXT",
),
mode = All,
)

def check(
options: TestOptions,
code: String,
Expand All @@ -672,11 +708,12 @@ class ScalaToplevelSuite extends BaseSuite {
val includeMembers = mode == All
val (doc, overrides) =
Mtags.indexWithOverrides(input, dialect, includeMembers)
val symbols = doc.occurrences.map(_.symbol).toList
val overriddenMap = overrides.toMap
symbols.map { symbol =>
doc.symbols.map { symbolInfo =>
val symbol = symbolInfo.symbol
val suffix = if (symbolInfo.isExtension) " EXT" else ""
overriddenMap.get(symbol) match {
case None => symbol
case None => s"$symbol$suffix"
case Some(symbols) =>
val overridden =
symbols
Expand All @@ -685,7 +722,7 @@ class ScalaToplevelSuite extends BaseSuite {
case UnresolvedOverriddenSymbol(name) => name
}
.mkString(", ")
s"$symbol -> $overridden"
s"$symbol$suffix -> $overridden"
}
}
case Toplevel => Mtags.topLevelSymbols(input, dialect)
Expand Down

0 comments on commit b2b0ac5

Please sign in to comment.