Skip to content

Commit

Permalink
feat: Better diagnostics (#242)
Browse files Browse the repository at this point in the history
Enhance error messages and support some other cases at compile time.

Partly closes #238. Other cases are not doable because Scala cancel any
folding/inlining for them.
  • Loading branch information
Iltotore authored Jun 20, 2024
1 parent 623835e commit 37336d9
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 156 deletions.
271 changes: 148 additions & 123 deletions main/src/io/github/iltotore/iron/macros/ReflectUtil.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package io.github.iltotore.iron.macros

import io.github.iltotore.iron.compileTime.NumConstant

import scala.quoted.*

/**
Expand All @@ -22,6 +20,15 @@ class ReflectUtil[Q <: Quotes & Singleton](using val _quotes: Q):

import _quotes.reflect.*

extension [T: Type](expr: Expr[T])

/**
* Decode this expression.
*
* @return the value of this expression found at compile time or a [[DecodingFailure]]
*/
def decode: Either[DecodingFailure, T] = ExprDecoder.decodeTerm(expr.asTerm, Map.empty)

/**
* A decoding failure.
*/
Expand Down Expand Up @@ -62,6 +69,8 @@ class ReflectUtil[Q <: Quotes & Singleton](using val _quotes: Q):
*/
case ApplyNotInlined(name: String, parameters: List[Either[DecodingFailure, ?]])

case VarArgsNotInlined(args: List[Either[DecodingFailure, ?]])

/**
* A boolean OR is not inlined.
*
Expand Down Expand Up @@ -90,14 +99,19 @@ class ReflectUtil[Q <: Quotes & Singleton](using val _quotes: Q):
*/
case InterpolatorNotInlined(name: String)

/**
* An unknown failure.
*/
case Unknown

/**
* Pretty print this failure.
*
* @param bodyIdent the identation of the 2nd+ lines
* @param firstLineIdent the identation of the first line
* @return a pretty-formatted [[String]] representation of this failure
*/
def prettyPrint(bodyIdent: Int = 0, firstLineIdent: Int = 0): String =
def prettyPrint(bodyIdent: Int = 0, firstLineIdent: Int = 0)(using Printer[Tree]): String =
val unindented = this match
case NotInlined(term) => s"Term not inlined: ${term.show}"
case DefinitionNotInlined(name) => s"Definition not inlined: $name. Only vals and zero-arg def can be inlined."
Expand All @@ -117,29 +131,38 @@ class ReflectUtil[Q <: Quotes & Singleton](using val _quotes: Q):

s"Some arguments of `$name` are not inlined:\n$errors"

case VarArgsNotInlined(args) =>
val errors = args
.zipWithIndex
.collect:
case (Left(failure), i) => s"Arg $i:\n${failure.prettyPrint(2, 2)}"
.mkString("\n\n")

s"Some varargs are not inlined:\n$errors"

case OrNotInlined(left, right) =>
s"""Non-inlined boolean or. The following patterns are evaluable at compile-time:
|- <inlined value> || <inlined value>
|- <inlined value> || true
|- true || <inlined value>
|
|Left member:
|${left.fold(_.prettyPrint(2, 2), _.toString)}
|${left.fold(_.prettyPrint(2, 2), " " + _)}
|
|Right member:
|${right.fold(_.prettyPrint(2, 2), _.toString)}""".stripMargin
|${right.fold(_.prettyPrint(2, 2), " " + _)}""".stripMargin

case AndNotInlined(left, right) =>
s"""Non-inlined boolean or. The following patterns are evaluable at compile-time:
|- <inlined value> || <inlined value>
|- <inlined value> || true
|- true || <inlined value>
s"""Non-inlined boolean and. The following patterns are evaluable at compile-time:
|- <inlined value> && <inlined value>
|- <inlined value> && false
|- false && <inlined value>
|
|Left member:
|${left.fold(_.prettyPrint(2, 2), _.toString)}
|${left.fold(_.prettyPrint(2, 2), " " + _)}
|
|Right member:
|${right.fold(_.prettyPrint(2, 2), _.toString)}""".stripMargin
|${right.fold(_.prettyPrint(2, 2), " " + _)}""".stripMargin

case StringPartsNotInlined(parts) =>
val errors = parts
Expand All @@ -148,155 +171,157 @@ class ReflectUtil[Q <: Quotes & Singleton](using val _quotes: Q):
case (Left(failure), i) => s"Arg $i:\n${failure.prettyPrint(2, 2)}"
.mkString("\n\n")

s"String contatenation as non inlined arguments:\n$errors"
s"String contatenation has non inlined arguments:\n$errors"

case InterpolatorNotInlined(name) => s"This interpolator is not supported: $name. Only `s` and `raw` are supported."

case Unknown => "Unknown reason"

" " * firstLineIdent + unindented.replaceAll("(\r\n|\n|\r)", "$1" + " " * bodyIdent)

override def toString: String = prettyPrint()
object ExprDecoder:

/**
* A compile-time [[Expr]] decoder. Like [[FromExpr]] with more fine-grained errors.
*
* @tparam T the type of the expression to decodeExpr
*/
trait ExprDecoder[T]:
private val enhancedDecoders: Map[TypeRepr, (Term, Map[String, ?]) => Either[DecodingFailure, ?]] = Map(
TypeRepr.of[Boolean] -> decodeBoolean,
TypeRepr.of[String] -> decodeString
)

/**
* Decode the given expression.
* Decode a term.
*
* @param expr the expression to decodeExpr
* @return the value decoded from [[expr]] or a [[DecodingFailure]] instead
* @param tree the term to decode
* @param definitions the decoded definitions in scope
* @tparam T the expected type of this term used as implicit cast for convenience
* @return the value of the given term found at compile time or a [[DecodingFailure]]
*/
def decodeExpr(expr: Expr[T]): Either[DecodingFailure, T]

extension [T](expr: Expr[T])

def decode(using decoder: ExprDecoder[T]): Either[DecodingFailure, T] = decoder.decodeExpr(expr)

object ExprDecoder:
def decodeTerm[T](tree: Term, definitions: Map[String, ?]): Either[DecodingFailure, T] =
val specializedResult = enhancedDecoders
.collectFirst:
case (k, v) if k =:= tree.tpe => v
.toRight(DecodingFailure.Unknown)
.flatMap(_.apply(tree, definitions))

specializedResult match
case Left(DecodingFailure.Unknown) => decodeUnspecializedTerm(tree, definitions)
case result => result.asInstanceOf[Either[DecodingFailure, T]]

/**
* Fallback expression decoder instance using Dotty's [[FromExpr]]. Fails with a [[DecodingFailure.NotInlined]] if the
* underlying [[FromExpr]] returns [[None]].
* Decode a term using only unspecialized cases.
*
* @param tree the term to decode
* @param definitions the decoded definitions in scope
* @tparam T the expected type of this term used as implicit cast for convenience
* @return the value of the given term found at compile time or a [[DecodingFailure]]
*/
given [T](using fromExpr: FromExpr[T]): ExprDecoder[T] with

override def decodeExpr(expr: Expr[T]): Either[DecodingFailure, T] =
fromExpr.unapply(expr).toRight(DecodingFailure.NotInlined(expr.asTerm))

private class PrimitiveExprDecoder[T <: NumConstant | Byte | Short | Boolean | String : Type] extends ExprDecoder[T]:

private def decodeBinding(definition: Definition): Either[DecodingFailure, T] = definition match
case ValDef(name, tpeTree, Some(term)) if tpeTree.tpe <:< TypeRepr.of[T] => decodeTerm(term)
case DefDef(name, Nil, tpeTree, Some(term)) if tpeTree.tpe <:< TypeRepr.of[T] => decodeTerm(term)
case _ => Left(DecodingFailure.DefinitionNotInlined(definition.name))

def decodeTerm(tree: Term): Either[DecodingFailure, T] = tree match
case block@Block(stats, e) => if stats.isEmpty then decodeTerm(e) else Left(DecodingFailure.HasStatements(block))
def decodeUnspecializedTerm[T](tree: Term, definitions: Map[String, ?]): Either[DecodingFailure, T] =
tree match
case block@Block(stats, e) => if stats.isEmpty then decodeTerm(e, definitions) else Left(DecodingFailure.HasStatements(block))

case Inlined(_, bindings, e) =>
val failures =
for
binding <- bindings
failure <- decodeBinding(binding).left.toOption
yield
(binding.name, failure)

if failures.isEmpty then decodeTerm(e)
else Left(DecodingFailure.HasBindings(failures))
val (failures, values) = bindings
.map[(String, Either[DecodingFailure, ?])](b => (b.name, decodeBinding(b, definitions)))
.partitionMap:
case (name, Right(value)) => Right((name, value))
case (name, Left(failure)) => Left((name, failure))

(failures, decodeTerm[T](e, definitions ++ values.toMap)) match
case (_, Right(value)) =>
Right(value)
case (Nil, Left(failure)) => Left(failure)
case (failures, Left(_)) => Left(DecodingFailure.HasBindings(failures))

case Apply(Select(left, "=="), List(right)) => (decodeTerm[Any](left, definitions), decodeTerm[Any](right, definitions)) match
case (Right(leftValue), Right(rightValue)) => Right((leftValue == rightValue).asInstanceOf[T])
case (leftResult, rightResult) => Left(DecodingFailure.ApplyNotInlined("==", List(leftResult, rightResult)))

case Typed(e, _) => decodeTerm(e)
case Apply(Select(leftOperand, name), operands) =>
val rightResults = operands.map(decodeTerm)
val rightResults = operands.map(decodeTerm(_, definitions))

val allResults = decodeTerm(leftOperand) match
val allResults = decodeTerm(leftOperand, definitions) match
case Left(DecodingFailure.ApplyNotInlined(n, leftResults)) if n == name =>
leftResults ++ rightResults
case leftResult =>
leftResult +: rightResults

Left(DecodingFailure.ApplyNotInlined(name, allResults))

case Repeated(terms, _) =>
var hasFailure = false
val results =
for term <- terms yield
val result = decodeTerm(term, definitions)
if result.isLeft then hasFailure = true
result

if hasFailure then Left(DecodingFailure.VarArgsNotInlined(results))
else Right(results.map(_.getOrElse((???): String)).asInstanceOf[T])

case Typed(e, _) => decodeTerm(e, definitions)

case Ident(name) => definitions
.get(name)
.toRight(DecodingFailure.NotInlined(tree))
.asInstanceOf[Either[DecodingFailure, T]]

case _ =>
tree.tpe.widenTermRefByName match
case ConstantType(c) => Right(c.value.asInstanceOf[T])
case _ => Left(DecodingFailure.NotInlined(tree))

override def decodeExpr(expr: Expr[T]): Either[DecodingFailure, T] =
decodeTerm(expr.asTerm)

/**
* Decoder for all primitives except for [[String]] and [[Boolean]] which benefit from some enhancements.
*
* @tparam T the type of the expression to decodeExpr
* Decode a binding/definition.
*
* @param definition the definition to decode
* @param definitions the definitions already decoded in scope
* @tparam T the expected type of this term used as implicit cast for convenience
* @return the value of the given definition found at compile time or a [[DecodingFailure]]
*/
given [T <: NumConstant | Byte | Short : Type]: ExprDecoder[T] = new PrimitiveExprDecoder[T]
def decodeBinding[T](definition: Definition, definitions: Map[String, ?]): Either[DecodingFailure, T] = definition match
case ValDef(name, tpeTree, Some(term)) => decodeTerm(term, definitions)
case DefDef(name, Nil, tpeTree, Some(term)) => decodeTerm(term, definitions)
case _ => Left(DecodingFailure.DefinitionNotInlined(definition.name))

/**
* A boolean [[ExprDecoder]] that can extract value from partially inlined || and
* && operations.
*
* {{{
* inline val x = true
* val y: Boolean = ???
*
* x || y //inlined to `true`
* y || x //inlined to `true`
* Decode a [[Boolean]] term using only [[Boolean]]-specific cases.
*
* inline val a = false
* val b: Boolean = ???
*
* a && b //inlined to `false`
* b && a //inlined to `false`
* }}}
* @param term the term to decode
* @param definitions the decoded definitions in scope
* @return the value of the given term found at compile time or a [[DecodingFailure]]
*/
given ExprDecoder[Boolean] = new PrimitiveExprDecoder[Boolean]:

override def decodeTerm(tree: Term): Either[DecodingFailure, Boolean] = tree match
case Apply(Select(left, "||"), List(right)) if left.tpe <:< TypeRepr.of[Boolean] && right.tpe <:< TypeRepr.of[Boolean] => // OR
(decodeTerm(left), decodeTerm(right)) match
case (Right(true), _) => Right(true)
case (_, Right(true)) => Right(true)
case (Right(leftValue), Right(rightValue)) => Right(leftValue || rightValue)
case (leftResult, rightResult) => Left(DecodingFailure.OrNotInlined(leftResult, rightResult))

case Apply(Select(left, "&&"), List(right)) if left.tpe <:< TypeRepr.of[Boolean] && right.tpe <:< TypeRepr.of[Boolean] => // AND
(decodeTerm(left), decodeTerm(right)) match
case (Right(false), _) => Right(false)
case (_, Right(false)) => Right(false)
case (Right(leftValue), Right(rightValue)) => Right(leftValue && rightValue)
case (leftResult, rightResult) => Left(DecodingFailure.AndNotInlined(leftResult, rightResult))

case _ => super.decodeTerm(tree)
def decodeBoolean(term: Term, definitions: Map[String, ?]): Either[DecodingFailure, Boolean] = term match
case Apply(Select(left, "||"), List(right)) if left.tpe <:< TypeRepr.of[Boolean] && right.tpe <:< TypeRepr.of[Boolean] => // OR
(decodeTerm[Boolean](left, definitions), decodeTerm[Boolean](right, definitions)) match
case (Right(true), _) => Right(true)
case (_, Right(true)) => Right(true)
case (Right(leftValue), Right(rightValue)) => Right(leftValue || rightValue)
case (leftResult, rightResult) => Left(DecodingFailure.OrNotInlined(leftResult, rightResult))

case Apply(Select(left, "&&"), List(right)) if left.tpe <:< TypeRepr.of[Boolean] && right.tpe <:< TypeRepr.of[Boolean] => // AND
(decodeTerm[Boolean](left, definitions), decodeTerm[Boolean](right, definitions)) match
case (Right(false), _) => Right(false)
case (_, Right(false)) => Right(false)
case (Right(leftValue), Right(rightValue)) => Right(leftValue && rightValue)
case (leftResult, rightResult) => Left(DecodingFailure.AndNotInlined(leftResult, rightResult))

case _ => Left(DecodingFailure.Unknown)

/**
* A String [[ExprDecoder]] that can extract value from concatenated strings if all
* arguments are compile-time-extractable strings.
*
* {{{
* inline val x = "a"
* inline val y = "b"
* val z = "c"
* Decode a [[String]] term using only [[String]]-specific cases.
*
* x + y //"ab"
* x + z //DecodingFailure
* z + x //DecodingFailure
* }}}
* @param term the term to decode
* @param definitions the decoded definitions in scope
* @return the value of the given term found at compile time or a [[DecodingFailure]]
*/
given ExprDecoder[String] = new PrimitiveExprDecoder[String]:

override def decodeTerm(tree: Term): Either[DecodingFailure, String] = tree match
case Apply(Select(left, "+"), List(right)) if left.tpe <:< TypeRepr.of[String] && right.tpe <:< TypeRepr.of[String] =>
(decodeTerm(left), decodeTerm(right)) match
case (Right(leftValue), Right(rightValue)) => Right(leftValue + rightValue)
case (Left(DecodingFailure.StringPartsNotInlined(lparts)), Left(DecodingFailure.StringPartsNotInlined(rparts))) =>
Left(DecodingFailure.StringPartsNotInlined(lparts ++ rparts))
case (Left(DecodingFailure.StringPartsNotInlined(lparts)), rightResult) =>
Left(DecodingFailure.StringPartsNotInlined(lparts :+ rightResult))
case (leftResult, Left(DecodingFailure.StringPartsNotInlined(rparts))) =>
Left(DecodingFailure.StringPartsNotInlined(leftResult +: rparts))
case (leftResult, rightResult) => Left(DecodingFailure.StringPartsNotInlined(List(leftResult, rightResult)))

case _ => super.decodeTerm(tree)
def decodeString(term: Term, definitions: Map[String, ?]): Either[DecodingFailure, String] = term match
case Apply(Select(left, "+"), List(right)) if left.tpe <:< TypeRepr.of[String] && right.tpe <:< TypeRepr.of[String] =>
(decodeTerm[String](left, definitions), decodeTerm[String](right, definitions)) match
case (Right(leftValue), Right(rightValue)) => Right(leftValue + rightValue)
case (Left(DecodingFailure.StringPartsNotInlined(lparts)), Left(DecodingFailure.StringPartsNotInlined(rparts))) =>
Left(DecodingFailure.StringPartsNotInlined(lparts ++ rparts))
case (Left(DecodingFailure.StringPartsNotInlined(lparts)), rightResult) =>
Left(DecodingFailure.StringPartsNotInlined(lparts :+ rightResult))
case (leftResult, Left(DecodingFailure.StringPartsNotInlined(rparts))) =>
Left(DecodingFailure.StringPartsNotInlined(leftResult +: rparts))
case (leftResult, rightResult) => Left(DecodingFailure.StringPartsNotInlined(List(leftResult, rightResult)))

case _ => Left(DecodingFailure.Unknown)
Loading

0 comments on commit 37336d9

Please sign in to comment.