Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Arbitrary arity tuple #435

Merged
merged 11 commits into from
May 22, 2024
63 changes: 46 additions & 17 deletions quill-sql/src/main/scala/io/getquill/generic/GenericDecoder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,31 @@ object GenericDecoder {
}
} // end flatten

// similar to flatten but without labels
@tailrec
def values[ResultRow: Type, Session: Type, Types: Type](
index: Int,
baseIndex: Expr[Int],
resultRow: Expr[ResultRow],
session: Expr[Session]
)(accum: List[FlattenData] = List())(using Quotes): List[FlattenData] = {
import quotes.reflect.{Term => QTerm, _}

Type.of[Types] match {
case '[tpe *: types] if Expr.summon[GenericDecoder[ResultRow, Session, tpe, DecodingType.Specific]].isEmpty =>
val result = decode[tpe, ResultRow, Session](index, baseIndex, resultRow, session)
val nextIndex = result.index + 1
values[ResultRow, Session, types](nextIndex, baseIndex, resultRow, session)(result +: accum)
case '[tpe *: types] =>
val result = decode[tpe, ResultRow, Session](index, baseIndex, resultRow, session, None)
val nextIndex = index + 1
values[ResultRow, Session, types](nextIndex, baseIndex, resultRow, session)(result +: accum)
case '[EmptyTuple] => accum

case typesTup => report.throwError("Cannot Derive Product during Values extraction:\n" + typesTup)
}
} // end values

def decodeOptional[T: Type, ResultRow: Type, Session: Type](index: Int, baseIndex: Expr[Int], resultRow: Expr[ResultRow], session: Expr[Session])(using Quotes): FlattenData = {
import quotes.reflect._
// Try to summon a specific optional from the context, this may not exist since
Expand Down Expand Up @@ -163,9 +188,9 @@ object GenericDecoder {
// List((new Name(Decoder("Joe") || Decoder("Bloggs")), Decoder(123))
// This is what needs to be fed into the constructor of the outer-entity i.e.
// new Person((new Name(Decoder("Joe") || Decoder("Bloggs")), Decoder(123))
val productElments = flattenData.map(_.decodedExpr)
val productElements = flattenData.map(_.decodedExpr)
// actually doing the construction i.e. `new Person(...)`
val constructed = ConstructDecoded[T](types, productElments, m)
val constructed = ConstructDecoded[T](types, productElements, m)

// E.g. for Person("Joe", 123) the List(q"!nullChecker(0,row)", q"!nullChecker(1,row)") columns
// that eventually turn into List(!NullChecker("Joe"), !NullChecker(123)) columns.
Expand All @@ -192,6 +217,11 @@ object GenericDecoder {
TypeRepr.of[T] <:< TypeRepr.of[Option[Any]]
}

private def isTuple[T: Type](using Quotes) = {
import quotes.reflect._
TypeRepr.of[T] <:< TypeRepr.of[Tuple]
}

private def isBuiltInType[T: Type](using Quotes) = {
import quotes.reflect._
isOption[T] || (TypeRepr.of[T] <:< TypeRepr.of[Seq[_]])
Expand All @@ -207,6 +237,16 @@ object GenericDecoder {
case '[Option[tpe]] =>
decodeOptional[tpe, ResultRow, Session](index, baseIndex, resultRow, session)
}
} else if (isTuple[T]) {
if (TypeRepr.of[T] <:< TypeRepr.of[EmptyTuple]) {
FlattenData(Type.of[T], '{ EmptyTuple }, '{ false }, index)
} else {
val flattenData = values[ResultRow, Session, T](index, baseIndex, resultRow, session)().reverse
val elementTerms = flattenData.map(_.decodedExpr) // expressions that represent values for tuple elements
val constructed = '{ scala.runtime.Tuples.fromArray(${ Varargs(elementTerms) }.toArray[Any](Predef.summon[ClassTag[Any]]).asInstanceOf[Array[Object]]).asInstanceOf[T] }
val nullChecks = flattenData.map(_._3).reduce((a, b) => '{ $a || $b })
FlattenData(Type.of[T], constructed, nullChecks, flattenData.last.index)
}
} else {
// specifically if there is a decoder found, allow optional override of the index via a resolver
val decoderIndex = overriddenIndex.getOrElse(elementIndex)
Expand Down Expand Up @@ -341,21 +381,10 @@ object ConstructDecoded {
val tpe = TypeRepr.of[T]
val constructor = TypeRepr.of[T].typeSymbol.primaryConstructor
// If we are a tuple, we can easily construct it
if (tpe <:< TypeRepr.of[Tuple]) {
val construct =
Apply(
TypeApply(
Select(New(TypeTree.of[T]), constructor),
types.map { tpe =>
tpe match {
case '[tt] => TypeTree.of[tt]
}
}
),
terms.map(_.asTerm)
)
// println(s"=========== Create from Tuple Constructor ${Format.Expr(construct.asExprOf[T])} ===========")
construct.asExprOf[T]
if (tpe <:< TypeRepr.of[EmptyTuple]) {
'{EmptyTuple}
} else if (tpe <:< TypeRepr.of[Tuple]) {
'{scala.runtime.Tuples.fromIArray(IArray(${Varargs(terms)})).asInstanceOf[T]}
// If we are a case class with no generic parameters, we can easily construct it
} else if (tpe.classSymbol.exists(_.flags.is(Flags.Case)) && !constructor.paramSymss.exists(_.exists(_.isTypeParam))) {
val construct =
Expand Down
108 changes: 108 additions & 0 deletions quill-sql/src/main/scala/io/getquill/parser/Parser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ trait ParserLibrary extends ParserFactory {
protected def functionParser(using Quotes, TranspileConfig) = ParserChain.attempt(FunctionParser(_))
protected def functionApplyParser(using Quotes, TranspileConfig) = ParserChain.attempt(FunctionApplyParser(_))
protected def valParser(using Quotes, TranspileConfig) = ParserChain.attempt(ValParser(_))
protected def arbitraryTupleParser(using Quotes, TranspileConfig) = ParserChain.attempt(ArbitraryTupleBlockParser(_))
protected def blockParser(using Quotes, TranspileConfig) = ParserChain.attempt(BlockParser(_))
protected def extrasParser(using Quotes, TranspileConfig) = ParserChain.attempt(ExtrasParser(_))
protected def operationsParser(using Quotes, TranspileConfig) = ParserChain.attempt(OperationsParser(_))
Expand Down Expand Up @@ -88,6 +89,7 @@ trait ParserLibrary extends ParserFactory {
.orElse(functionParser) // decided to have it be it's own parser unlike Quill3
.orElse(patMatchParser)
.orElse(valParser)
.orElse(arbitraryTupleParser)
.orElse(blockParser)
.orElse(operationsParser)
.orElse(extrasParser)
Expand Down Expand Up @@ -153,6 +155,112 @@ class ValParser(val rootParse: Parser)(using Quotes, TranspileConfig)
case Unseal(ValDefTerm(ast)) => ast
}
}
/**
* Matches `runtime.Tuples.cons(head,tail)`.
*/
object TupleCons {
def unapply(using Quotes)(t: quotes.reflect.Term): Option[(quotes.reflect.Term, quotes.reflect.Term)] = {
import quotes.reflect.*
t match {
case Apply(Select(Select(Ident("runtime"), "Tuples"), "cons"), List(head, tail)) =>
Some((head, tail))
case _ =>
None
}
}
}

/**
* Matches inner.asInstanceOf[T]: T
*/
object AsInstanceOf {
def unapply(using Quotes)(term: quotes.reflect.Term): Option[quotes.reflect.Term] = {
import quotes.reflect._
term match {
case TypeApply(Select(inner, "asInstanceOf"), _) => Some(inner)
case _ => None
}
}
}

/**
* Matches an inlined call to `Tuple.*:`:
* {{{
* {
* val Tuple_this: scala.Tuple$package.EmptyTuple.type = scala.Tuple$package.EmptyTuple
*
* (scala.runtime.Tuples.cons(i, Tuple_this).asInstanceOf[scala.*:[scala.Int, scala.Tuple$package.EmptyTuple.type]]: scala.*:[scala.Int, scala.Tuple$package.EmptyTuple])
* }
* }}}
*/
object ArbitraryTupleConstructionInlined {
def unapply(using Quotes)(t: quotes.reflect.Term): Option[(quotes.reflect.Term, quotes.reflect.Term)] = {
import quotes.reflect.{Ident => TIdent, *}
t match {
case
Inlined(
_,
List(ValDef("Tuple_this", _, Some(prevTuple))),
Typed(AsInstanceOf(TupleCons(head, TIdent("Tuple_this"))), _)
) =>
Some((head, prevTuple))
case _ => None
}
}
}

/**
* Parses a few cases of arbitrary tuples.
*
* Scala 3 produces a few different trees for arbitrary tuples. Method `*:` is marked as inline.
* Under the hood it actually invokes `Tuples.cons` function:
* {{{
* inline def *: [H, This >: this.type <: Tuple] (x: H): H *: This =
* runtime.Tuples.cons(x, this).asInstanceOf[H *: This]
* }}}
* So, at least we have to match Tuples.cons.
* However, it's not the only variation. Scala also produces a block with intermediate val `Tuple_this` definitions:
* {{{
* {
* val Tuple_this: scala.Tuple$package.EmptyTuple.type = scala.Tuple$package.EmptyTuple
* val `Tuple_this₂`: scala.*:[java.lang.String, scala.Tuple$package.EmptyTuple] = (scala.runtime.Tuples.cons("", Tuple_this).asInstanceOf[scala.*:[java.lang.String, scala.Tuple$package.EmptyTuple.type]]: scala.*:[java.lang.String, scala.Tuple$package.EmptyTuple])
*
* (scala.runtime.Tuples.cons(1, `Tuple_this₂`).asInstanceOf[scala.*:[scala.Int, scala.*:[java.lang.String, scala.Tuple$package.EmptyTuple]]]: scala.*:[scala.Int, scala.*:[java.lang.String, scala.Tuple$package.EmptyTuple]])
* }
* }}}
*/
class ArbitraryTupleBlockParser(val rootParse: Parser)(using Quotes, TranspileConfig)
extends Parser(rootParse)
with PatternMatchingValues {

import quotes.reflect.{Block => TBlock, Ident => TIdent, _}

def attempt = {
case '{EmptyTuple} =>
ast.Tuple(List())
case '{$a *: EmptyTuple} =>
val aAst = rootParse(a)
ast.Tuple(List(aAst))
case inlined@Unseal(ArbitraryTupleConstructionInlined(singleValue, prevTuple)) =>
val headAst = rootParse(singleValue.asExpr)
val prevTupleAst = rootParse(prevTuple.asExpr)
prevTupleAst match {
case ast.Tuple(lst) => ast.Tuple(headAst :: lst)
case _ =>
throw IllegalArgumentException(s"Unexpected tuple ast ${prevTupleAst}")
}
case block@Unseal(TBlock(parts, ArbitraryTupleConstructionInlined(head,Typed(AsInstanceOf(TupleCons(head2,TIdent("Tuple_this"))), _)) )) if (parts.length > 0) =>
val headAst = rootParse(head.asExpr)
val head2Ast = rootParse(head2.asExpr)
val partsAsts = headAst :: head2Ast :: parts.reverse.flatMap{
case ValDef("Tuple_this", tpe, Some(TIdent("EmptyTuple"))) => List()
case ValDef("Tuple_this", tpe, Some(Typed(AsInstanceOf(TupleCons(next,TIdent("Tuple_this"))), _))) => List(rootParse(next.asExpr))
case ValDef("Tuple_this", tpe, Some(unknown)) =>
throw IllegalArgumentException(s"Unexpected Tuple_this = ${unknown.show}")
}
Tuple(partsAsts)
}
}

class BlockParser(val rootParse: Parser)(using Quotes, TranspileConfig)
extends Parser(rootParse)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,12 @@ object ParserHelpers {
binds.zipWithIndex.flatMap { case (bind, idx) =>
tupleBindsPath(bind, path :+ s"_${idx + 1}")
}
case Unapply(TypeApply(Select(TIdent("*:"), "unapply"), types), implicits, List(h, t)) =>
List(
tupleBindsPath(h, path :+ s"head"),
tupleBindsPath(h, path :+ s"tail")
)
.flatten
// If it's a "case _ => ..." then that just translates into the body expression so we don't
// need a clause to beta reduction over the entire partial-function
case TIdent("_") =>
Expand Down
22 changes: 22 additions & 0 deletions quill-sql/src/main/scala/io/getquill/quat/QuatMaking.scala
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,25 @@ trait QuatMakingBase {
None
}

object ArbitraryArityTupleType {
def unapply(using Quotes)(tpe: quotes.reflect.TypeRepr): Option[List[quotes.reflect.TypeRepr]] =
if (tpe.is[Tuple])
Some(tupleParts(tpe))
else
None

@tailrec
def tupleParts(using Quotes)(tpe: quotes.reflect.TypeRepr, accum: List[quotes.reflect.TypeRepr] = Nil): List[quotes.reflect.TypeRepr] =
tpe.asType match {
case '[h *: t] =>
val htpe = quotes.reflect.TypeRepr.of[h]
val ttpe = quotes.reflect.TypeRepr.of[t]
tupleParts(ttpe, htpe :: accum)
case '[EmptyTuple] =>
accum.reverse
}
}

object OptionType {
def unapply(using Quotes)(tpe: quotes.reflect.TypeRepr): Option[quotes.reflect.TypeRepr] = {
import quotes.reflect._
Expand Down Expand Up @@ -384,6 +403,9 @@ trait QuatMakingBase {
case CaseClassBaseType(name, fields) if !existsEncoderFor(tpe) || tpe <:< TypeRepr.of[Udt] =>
Quat.Product(name, fields.map { case (fieldName, fieldType) => (fieldName, parseType(fieldType)) })

case ArbitraryArityTupleType(tupleParts) =>
Quat.Product("Tuple", tupleParts.zipWithIndex.map { case (fieldType, idx) => (s"_${idx + 1}", parseType(fieldType)) })

// If we are already inside a bounded type, treat an arbitrary type as a interface list
case ArbitraryBaseType(name, fields) if (boundedInterfaceType) =>
Quat.Product(name, fields.map { case (fieldName, fieldType) => (fieldName, parseType(fieldType)) })
Expand Down
122 changes: 122 additions & 0 deletions quill-sql/src/test/scala/io/getquill/ArbitraryTupleSpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package io.getquill

import io.getquill.context.ExecutionType.Static
import io.getquill.context.mirror.{MirrorSession, Row}
import io.getquill.generic.TupleMember

class ArbitraryTupleSpec extends Spec {

val ctx = new MirrorContext(PostgresDialect, Literal)
import ctx._

type MyRow1 = (Int, String)
type MyRow2 = Int *: String *: EmptyTuple

inline def myRow1Query = quote {
querySchema[MyRow1]("my_table", t => t._1 -> "int_field", t => t._2 -> "string_field")
}

inline def myRow2Query = quote {
querySchema[MyRow2]("my_table", t => t._1 -> "int_field", t => t._2 -> "string_field")
}

"ordinary tuple" in {
val result = ctx.run(myRow1Query)

result.string mustEqual "SELECT x.int_field, x.string_field FROM my_table x"
result.extractor(Row(123, "St"), MirrorSession.default) mustEqual
(123, "St")
}

"ordinary tuple swap" in {

transparent inline def swapped: Quoted[EntityQuery[(String, Int)]] = quote {
myRow1Query.map {
case (i, s) => (s, i)
}
}

val result = ctx.run(swapped)

result.string mustEqual "SELECT x$1.string_field AS _1, x$1.int_field AS _2 FROM my_table x$1"
require(result.extractor(Row("St", 123), MirrorSession.default) == ("St", 123))
}

"arbitrary long tuple" in {
val result = ctx.run(myRow2Query)

result.extractor(Row(123, "St"), MirrorSession.default) mustEqual
(123, "St")
}

"get field of arbitrary long tuple" in {
inline def g = quote{
myRow2Query.map{
case h *: tail => h
}
}
val result = ctx.run(g)

result.extractor(Row(123, "St"), MirrorSession.default) mustEqual
(123)
}

"decode empty tuple" in {
inline def g = quote {
myRow2Query.map{
case (_, _) => EmptyTuple
}
}

val result = ctx.run(g)

result.extractor(Row(123, "St"), MirrorSession.default) mustEqual EmptyTuple
}

"construct tuple1" in {
inline def g = quote {
myRow1Query.map {
case (i, s) => i *: EmptyTuple
}
}

val result = ctx.run(g)

result.string mustEqual "SELECT x$1.int_field AS _1 FROM my_table x$1"
result.extractor(Row(123, "St"), MirrorSession.default) mustEqual
Tuple1(123)
}

"construct arbitrary tuple" in {
inline def g = quote {
myRow1Query.map {
case (i, s) => s *: i *: EmptyTuple
}
}
val result = ctx.run(g)

result.string mustEqual "SELECT x$1.string_field AS _1, x$1.int_field AS _2 FROM my_table x$1"
result.extractor(Row("St", 123), MirrorSession.default) mustEqual ("St", 123)

}

"constant arbitrary tuple" in {
inline def g = quote {
123 *: "St" *: true *: (3.14 *: EmptyTuple)
}
val result = ctx.run(g)
result.string mustEqual "SELECT 123 AS _1, 'St' AS _2, true AS _3, 3.14 AS _4"
result.info.executionType mustEqual Static
result.extractor(Row(123, "St", true, 3.14), MirrorSession.default) mustEqual (123, "St", true, 3.14)
}

"constant arbitrary tuple 1" in {
inline def g = quote {
(3.14 *: EmptyTuple)
}
val result = ctx.run(g)
result.string mustEqual "SELECT 3.14 AS _1"
result.info.executionType mustEqual Static
result.extractor(Row(3.14), MirrorSession.default) mustEqual Tuple1(3.14)
}
}