diff --git a/README.md b/README.md index fb8c679..baef12b 100644 --- a/README.md +++ b/README.md @@ -194,6 +194,8 @@ val q3 = query[Person](p => Pattern.matches("^[\\w-\\.]+@([\\w-]+\\.)+[\\w-]{2,4 V Array Query Operators +1. $size + ```scala import oolong.dsl.* @@ -204,6 +206,21 @@ val q = query[Course](_.studentNames.length == 20) // q is {"studentNames": {"$size": 20}} ``` +2. $elemMatch + +```scala +import oolong.dsl.* + +case class Course(studentNames: List[String], tutor: String) + +val q = query[Course](_.studentNames.exists(_ == 20)) // $elemMatch ommited when querying single field +// q is {"studentNames": 20} + +val q = query[Course](course => course.studentNames.exists(_ > 20) && course.tutor == "Pavlov") +// q is {"studentNames": {"$elemMatch": {"studentNames": {"$gt": 20}, "tutor": "Pavlov"}}} + +``` + #### Update operators I Field Update Operators diff --git a/oolong-core/src/main/scala/oolong/AstParser.scala b/oolong-core/src/main/scala/oolong/AstParser.scala index a32f3f6..2f9e331 100644 --- a/oolong-core/src/main/scala/oolong/AstParser.scala +++ b/oolong-core/src/main/scala/oolong/AstParser.scala @@ -5,7 +5,6 @@ import scala.annotation.tailrec import scala.language.postfixOps import scala.quoted.* -import oolong.AstParser import oolong.UExpr.FieldUpdateExpr import oolong.Utils.* import oolong.dsl.* @@ -63,6 +62,9 @@ private[oolong] class DefaultAstParser(using quotes: Quotes) extends AstParser { case '{ ($x: Seq[_]).length == ($y: Int) } => QExpr.Size(parse(x), parse(y)) + case '{ type t; ($x: Seq[`t`]).exists($y: (`t` => Boolean)) } => // not text & where + QExpr.ElemMatch(parse(x), parseQExpr(y)) + case AsTerm(Apply(Select(lhs, "<="), List(rhs))) => QExpr.Lte(parse(lhs.asExpr), parse(rhs.asExpr)) diff --git a/oolong-core/src/main/scala/oolong/Backend.scala b/oolong-core/src/main/scala/oolong/Backend.scala index b03475e..403c720 100644 --- a/oolong-core/src/main/scala/oolong/Backend.scala +++ b/oolong-core/src/main/scala/oolong/Backend.scala @@ -24,5 +24,5 @@ private[oolong] trait Backend[Ast, OptimizableRepr, TargetRepr] { /** * Perform optimizations that are specific to this backend. */ - def optimize(query: OptimizableRepr): OptimizableRepr = query + def optimize(query: OptimizableRepr)(using quotes: Quotes): OptimizableRepr = query } diff --git a/oolong-core/src/main/scala/oolong/LogicalOptimizer.scala b/oolong-core/src/main/scala/oolong/LogicalOptimizer.scala index 4293cc0..06f5507 100644 --- a/oolong-core/src/main/scala/oolong/LogicalOptimizer.scala +++ b/oolong-core/src/main/scala/oolong/LogicalOptimizer.scala @@ -35,11 +35,12 @@ private[oolong] object LogicalOptimizer { } ast match { - case QExpr.And(children) => flatten(QExpr.And(children.map(optimize))) - case QExpr.Or(children) => flatten(QExpr.Or(children.map(optimize))) - case QExpr.Not(QExpr.Not(e)) => e - case QExpr.Not(QExpr.Eq(l, r)) => QExpr.Ne(l, r) - case _ => ast + case QExpr.ElemMatch(field, expr) => QExpr.ElemMatch(field, flatten(expr)) + case QExpr.And(children) => flatten(QExpr.And(children.map(optimize))) + case QExpr.Or(children) => flatten(QExpr.Or(children.map(optimize))) + case QExpr.Not(QExpr.Not(e)) => e + case QExpr.Not(QExpr.Eq(l, r)) => QExpr.Ne(l, r) + case _ => ast } } } diff --git a/oolong-core/src/main/scala/oolong/QExpr.scala b/oolong-core/src/main/scala/oolong/QExpr.scala index 25d1d62..7eda114 100644 --- a/oolong-core/src/main/scala/oolong/QExpr.scala +++ b/oolong-core/src/main/scala/oolong/QExpr.scala @@ -51,4 +51,6 @@ private[oolong] object QExpr { case class TypeCheck[T](x: QExpr, typeInfo: TypeInfo[T]) extends QExpr case class Mod(x: QExpr, divisor: QExpr, remainder: QExpr) extends QExpr + + case class ElemMatch(x: QExpr, y: QExpr) extends QExpr } diff --git a/oolong-mongo/src/main/scala/oolong/mongo/MongoQueryCompiler.scala b/oolong-mongo/src/main/scala/oolong/mongo/MongoQueryCompiler.scala index 4255574..6e5d18b 100644 --- a/oolong-mongo/src/main/scala/oolong/mongo/MongoQueryCompiler.scala +++ b/oolong-mongo/src/main/scala/oolong/mongo/MongoQueryCompiler.scala @@ -1,13 +1,14 @@ package oolong.mongo import java.util.regex.Pattern +import scala.annotation.tailrec import scala.jdk.CollectionConverters.* import scala.quoted.Expr import scala.quoted.Quotes import scala.quoted.Type +import scala.util.chaining.* import oolong.* -import oolong.TypeInfo import oolong.Utils.PatternInstance.given import oolong.bson.* import oolong.bson.meta.QueryMeta @@ -96,19 +97,27 @@ object MongoQueryCompiler extends Backend[QExpr, MQ, BsonDocument] { report.errorAndAbort(s"Expected the subquery inside 'unchecked(...)' to have 'org.mongodb.scala.bson.BsonDocument' type, but the subquery is '${code.show}'") } case not: QExpr.Not => handleInnerNot(not)(renames) + case elemMatch: QExpr.ElemMatch => + val cond = rec(elemMatch.y, renames) + val query = cond match + case f: MQ.OnField => MQ.ElemMatch(f) + case and: MQ.And => MQ.ElemMatch(and) + case or: MQ.Or => MQ.ElemMatch(or) + case _ => report.errorAndAbort(s"Wrong condition: ${cond}") + MQ.OnField(getField(elemMatch.x)(renames), query) } rec(ast, meta) } - def getField(f: QExpr)(renames: Map[String, String])(using quotes: Quotes): MQ.Field = + private def getField(f: QExpr)(renames: Map[String, String])(using quotes: Quotes): MQ.Field = import quotes.reflect.* f match case QExpr.Prop(path) => MQ.Field(renames.getOrElse(path, path)) - case _ => report.errorAndAbort("Field is of wrong type") + case expr => report.errorAndAbort(s"Field is of wrong type: ${expr}") - def handleInnerNot(not: QExpr.Not)(renames: Map[String, String])(using quotes: Quotes): MongoQueryNode = + private def handleInnerNot(not: QExpr.Not)(renames: Map[String, String])(using quotes: Quotes): MongoQueryNode = import quotes.reflect.* not.x match case QExpr.Gte(x, y) => MQ.OnField(getField(x)(renames), MQ.Not(MQ.Gte(opt(y)))) @@ -123,7 +132,7 @@ object MongoQueryCompiler extends Backend[QExpr, MQ, BsonDocument] { case QExpr.Nin(x, y) => MQ.OnField(getField(x)(renames), MQ.Not(MQ.Nin(handleArrayConds(y)))) case _ => report.errorAndAbort("Wrong operator inside $not") - def handleArrayConds(x: List[QExpr] | QExpr)(using quotes: Quotes): List[MQ] | MQ = + private def handleArrayConds(x: List[QExpr] | QExpr)(using quotes: Quotes): List[MQ] | MQ = x match case list: List[QExpr @unchecked] => list map opt case expr: QExpr => opt(expr) @@ -132,15 +141,22 @@ object MongoQueryCompiler extends Backend[QExpr, MQ, BsonDocument] { import quotes.reflect.* def rec(node: MongoQueryNode)(using quotes: Quotes): String = node match - case MQ.OnField(prop, x) => "\"" + prop.path + "\"" + ": " + rec(x) - case MQ.Gte(x) => "{ \"$gte\": " + rec(x) + " }" - case MQ.Lte(x) => "{ \"$lte\": " + rec(x) + " }" - case MQ.Gt(x) => "{ \"$gt\": " + rec(x) + " }" - case MQ.Lt(x) => "{ \"$lt\": " + rec(x) + " }" - case MQ.Eq(x) => rec(x) - case MQ.Ne(x) => "{ \"$ne\": " + rec(x) + " }" - case MQ.Not(x) => "{ \"$not\": " + rec(x) + " }" - case MQ.Size(x) => "{ \"$size\": " + rec(x) + " }" + case MQ.OnField(prop, x) if prop.path.nonEmpty => "\"" + prop.path + "\"" + ": " + rec(x) + case MQ.OnField(_, x) => // when querying array of primitives + val res = x match + case MQ.Eq(x) => "{ \"$eq\": " + rec(x) + " }" + // fixes querying array of primitives with `==` and smth else (wrong query, but has to compile anyway) + case other => rec(other) + if res.startsWith("{ ") then res.drop(2).dropRight(2) + else res + case MQ.Gte(x) => "{ \"$gte\": " + rec(x) + " }" + case MQ.Lte(x) => "{ \"$lte\": " + rec(x) + " }" + case MQ.Gt(x) => "{ \"$gt\": " + rec(x) + " }" + case MQ.Lt(x) => "{ \"$lt\": " + rec(x) + " }" + case MQ.Eq(x) => rec(x) + case MQ.Ne(x) => "{ \"$ne\": " + rec(x) + " }" + case MQ.Not(x) => "{ \"$not\": " + rec(x) + " }" + case MQ.Size(x) => "{ \"$size\": " + rec(x) + " }" case MQ.Regex(pattern) => pattern.value match case Some(p: Pattern) => @@ -151,12 +167,12 @@ object MongoQueryCompiler extends Backend[QExpr, MQ, BsonDocument] { case MQ.In(exprs) => "{ \"$in\": [" + renderArrays(exprs) + "] }" case MQ.Nin(exprs) => "{ \"$nin\": [" + renderArrays(exprs) + "] }" case MQ.And(exprs) => - val fields = exprs.collect { case q: MQ.OnField => q.field.path.mkString(".") } - if (fields.distinct.size < fields.size) + val fields = exprs.collect { case q: MQ.OnField if q.field.path.nonEmpty => q.field.path.mkString(".") } + if ((fields.distinct.size < fields.size) && fields.nonEmpty) "\"$and\": [ " + exprs.map(rec).map("{ " + _ + " }").mkString(", ") + " ]" else exprs.map(rec).mkString(", ") case MQ.Or(exprs) => "\"$or\": [ " + exprs.map(rec).map("{ " + _ + " }").mkString(", ") + " ]" - case MQ.Exists(x) => " { \"$exists\": " + rec(x) + " }" + case MQ.Exists(x) => "{ \"$exists\": " + rec(x) + " }" case MQ.Constant(s: String) => "\"" + s + "\"" case MQ.Constant(s: Any) => s.toString // also limit case MQ.ScalaCode(code) => renderCode(code) @@ -164,6 +180,8 @@ object MongoQueryCompiler extends Backend[QExpr, MQ, BsonDocument] { case MQ.Subquery(_) => "{...}" case MQ.TypeCheck(bsonType) => "{ \"$type\": " + rec(bsonType) + " }" case MQ.Mod(divisor, remainder) => "{ \"$mod\": [" + rec(divisor) + "," + rec(remainder) + "] }" + case MQ.ElemMatch(expr) => + "{ \"$elemMatch\": { " + rec(expr) + " } }" case MQ.Field(field) => report.errorAndAbort(s"There is no filter condition on field ${field.mkString(".")}") end rec @@ -182,17 +200,24 @@ object MongoQueryCompiler extends Backend[QExpr, MQ, BsonDocument] { override def target(optRepr: MongoQueryNode)(using quotes: Quotes): Expr[BsonDocument] = import quotes.reflect.* optRepr match { - case and: MQ.And => handleAnd(and) - case or: MQ.Or => handleOr(or) - case MQ.OnField(prop, x) => '{ BsonDocument(${ Expr(prop.path) } -> ${ parseOperators(x) }) } - case MQ.Subquery(doc) => doc - case _ => report.errorAndAbort("given node can't be in that position") + case and: MQ.And => handleAnd(and) + case or: MQ.Or => handleOr(or) + case MQ.OnField(prop, x) if prop.path.nonEmpty => + '{ BsonDocument(${ Expr(prop.path) } -> ${ parseOperators(x) }) } + case MQ.OnField(_, x) => parseOperatorsAsBsonDocument(x) + case MQ.Subquery(doc) => doc + case _ => report.errorAndAbort(s"given node can't be in that position ${optRepr}") } - def parseOperators(optRepr: MongoQueryNode)(using quotes: Quotes): Expr[BsonValue] = + private def parseOperators(optRepr: MongoQueryNode)(using quotes: Quotes): Expr[BsonValue] = + parseEq.lift(optRepr).getOrElse(parseOperatorsAsBsonDocument(optRepr)) + private def parseEq(using quotes: Quotes): PartialFunction[MQ, Expr[BsonValue]] = + case MQ.Eq(x) => handleValues(x) + private def parseOperatorsAsBsonDocument(optRepr: MongoQueryNode)(using quotes: Quotes): Expr[BsonDocument] = import quotes.reflect.* optRepr match - case MQ.Eq(x) => handleValues(x) + case MQ.Eq(x) => + '{ BsonDocument("$eq" -> ${ handleValues(x) }) } case MQ.Gte(x) => '{ BsonDocument("$gte" -> ${ handleValues(x) }) } case MQ.Lte(x) => @@ -238,10 +263,15 @@ object MongoQueryCompiler extends Backend[QExpr, MQ, BsonDocument] { '{ BsonDocument("$mod" -> BsonArray.fromIterable(List(${ handleValues(divisor) }, ${ handleValues(remainder) }))) } + case MQ.ElemMatch(expr) => + '{ + BsonDocument( + "$elemMatch" -> ${ target(expr) } + ) + } case _ => report.errorAndAbort(s"Wrong operator: ${optRepr}") - end parseOperators - def handleArrayCond(x: List[MQ] | MQ)(using q: Quotes): Expr[BsonValue] = + private def handleArrayCond(x: List[MQ] | MQ)(using q: Quotes): Expr[BsonValue] = import q.reflect.* x match case list: List[MQ @unchecked] => @@ -259,7 +289,7 @@ object MongoQueryCompiler extends Backend[QExpr, MQ, BsonDocument] { } case _ => report.errorAndAbort("Incorrect condition for array") - def handleAnd(and: MQ.And)(using q: Quotes): Expr[BsonDocument] = + private def handleAnd(and: MQ.And)(using q: Quotes): Expr[BsonDocument] = '{ val exprs: List[BsonDocument] = ${ Expr.ofList(and.exprs.map(target)) } if (exprs.flatMap(_.keySet().asScala).distinct.size < exprs.size) @@ -267,14 +297,14 @@ object MongoQueryCompiler extends Backend[QExpr, MQ, BsonDocument] { else BsonDocument(exprs.map(_.asScala.toList).foldLeft(List.empty[(String, BsonValue)])(_ ++ _)) } - def handleOr(or: MQ.Or)(using q: Quotes): Expr[BsonDocument] = + private def handleOr(or: MQ.Or)(using q: Quotes): Expr[BsonDocument] = '{ BsonDocument("$or" -> BsonArray.fromIterable(${ Expr.ofList(or.exprs.map(target)) })) } - def handleValues(expr: MongoQueryNode)(using q: Quotes): Expr[BsonValue] = + private def handleValues(expr: MongoQueryNode)(using q: Quotes): Expr[BsonValue] = import q.reflect.* expr match { case MQ.Constant(i: Long) => @@ -299,13 +329,7 @@ object MongoQueryCompiler extends Backend[QExpr, MQ, BsonDocument] { case _ => report.errorAndAbort(s"Given type is not literal constant") } - def extractField(expr: MongoQueryNode)(using q: Quotes): Expr[String] = - import q.reflect.* - expr match - case MQ.Field(path) => Expr(path) - case _ => report.errorAndAbort("field should be string") - - def parsePattern(pattern: Pattern): (String, Option[String]) = + private def parsePattern(pattern: Pattern): (String, Option[String]) = val flags = List( if (pattern.flags & Pattern.CASE_INSENSITIVE) != 0 then Some("i") else None, if (pattern.flags & Pattern.MULTILINE) != 0 then Some("m") else None, @@ -317,6 +341,19 @@ object MongoQueryCompiler extends Backend[QExpr, MQ, BsonDocument] { val matcher = Pattern.compile("(\\(\\?([a-z]*)\\))?(.*)").matcher(pattern.pattern) matcher.matches() matcher.group(3) -> options - end parsePattern + override def optimize(query: MQ)(using quotes: Quotes): MQ = + val opt: PartialFunction[MQ, MQ] = { case q @ MQ.OnField(_, _: MQ.ElemMatch) => + optimizeElemMatch(q) + + } + opt.lift(query).getOrElse(query) + + private def optimizeElemMatch(elemMatch: MQ.OnField)(using quotes: Quotes): MQ = + import quotes.reflect.* + elemMatch match + case q @ MQ.OnField(_, MQ.ElemMatch(_: MQ.And | _: MQ.Or)) => q + case MQ.OnField(first, MQ.ElemMatch(MQ.OnField(second, expr))) => + optimize(MQ.OnField(MQ.Field(Vector(first.path, second.path).filter(_.nonEmpty).mkString(".")), expr)) // flatten $elemMatch for querying one field + case _ => report.errorAndAbort(s"Not a ElemMatch: ${elemMatch}") } diff --git a/oolong-mongo/src/main/scala/oolong/mongo/MongoQueryNode.scala b/oolong-mongo/src/main/scala/oolong/mongo/MongoQueryNode.scala index 88ee480..c851e5b 100644 --- a/oolong-mongo/src/main/scala/oolong/mongo/MongoQueryNode.scala +++ b/oolong-mongo/src/main/scala/oolong/mongo/MongoQueryNode.scala @@ -42,4 +42,6 @@ case object MongoQueryNode { case class TypeCheck(bsonType: Constant[Int]) extends MQ case class Mod(divisor: MQ, remainder: MQ) extends MQ + + case class ElemMatch(x: MQ) extends MQ } diff --git a/oolong-mongo/src/test/scala/oolong/mongo/QuerySpec.scala b/oolong-mongo/src/test/scala/oolong/mongo/QuerySpec.scala index 0ecb0b5..8b9c395 100644 --- a/oolong-mongo/src/test/scala/oolong/mongo/QuerySpec.scala +++ b/oolong-mongo/src/test/scala/oolong/mongo/QuerySpec.scala @@ -925,6 +925,102 @@ class QuerySpec extends AnyFunSuite { } + test("$elemMatch for array containing objects") { + case class Inner(a: Int, b: String) + case class Test(array: Vector[Inner]) + + val q = query[Test](_.array.exists(s => s.a > 2 && s.b == "123")) + val repr = renderQuery[Test](_.array.exists(s => s.a > 2 && s.b == "123")) + + test( + q, + repr, + BsonDocument( + "array" -> BsonDocument( + "$elemMatch" -> BsonDocument( + "a" -> BsonDocument("$gt" -> BsonInt32(2)), + "b" -> BsonString("123") + ) + ) + ) + ) + } + + test("$elemMatch for array primitives") { + case class Test(array: Vector[Int]) + + val q = query[Test](_.array.exists(s => s > 2 && s <= 100)) + val repr = renderQuery[Test](_.array.exists(s => s > 2 && s <= 100)) + + test( + q, + repr, + BsonDocument( + "array" -> BsonDocument( + "$elemMatch" -> BsonDocument( + "$gt" -> BsonInt32(2), + "$lte" -> BsonInt32(100) + ) + ) + ) + ) + } + + test("$elemMatch for array with $and with same field twice") { + case class Inner(a: Int, b: String) + case class Test(array: Vector[Inner]) + + val q = query[Test](_.array.exists(s => s.a > 2 && s.a < 100 && s.b == "123")) + val repr = renderQuery[Test](_.array.exists(s => s.a > 2 && s.a < 100 && s.b == "123")) + + test( + q, + repr, + BsonDocument( + "array" -> BsonDocument( + "$elemMatch" -> BsonDocument( + "$and" -> BsonArray.fromIterable( + List( + BsonDocument("a" -> BsonDocument("$gt" -> BsonInt32(2))), + BsonDocument("a" -> BsonDocument("$lt" -> BsonInt32(100))), + BsonDocument("b" -> BsonString("123")), + ) + ) + ) + ) + ) + ) + } + + test("$elemMatch querying a single field") { + case class Test(array: Vector[Int]) + + val q = query[Test](_.array.exists(s => s < 100)) + val repr = renderQuery[Test](_.array.exists(s => s < 100)) + + test( + q, + repr, + BsonDocument("array" -> BsonDocument("$lt" -> BsonInt32(100))) + ) + } + + test("nested $elemMatch") { + case class Inner(array1: Vector[Int]) + case class Base(array0: Vector[Inner]) + + val q = query[Base](_.array0.exists(_.array1.exists(_ > 100))) + val repr = renderQuery[Base](_.array0.exists(_.array1.exists(_ > 100))) + + test( + q, + repr, + BsonDocument( + "array0.array1" -> BsonDocument("$gt" -> BsonInt32(100)) + ) + ) + } + private inline def test( query: BsonDocument, repr: String,