Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add $elemMatch #14

Merged
merged 3 commits into from
Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,8 @@ val q3 = query[Person](p => Pattern.matches("^[\\w-\\.]+@([\\w-]+\\.)+[\\w-]{2,4

V Array Query Operators

1. $size

```scala
import oolong.dsl.*

Expand All @@ -204,6 +206,21 @@ val q = query[Course](_.studentNames.length == 20)
// q is {"studentNames": {"$size": 20}}
```

2. $elemMatch

```scala
import oolong.dsl.*

case class Course(studentNames: List[String], tutor: String)

val q = query[Course](_.studentNames.exists(_ == 20)) // $elemMatch ommited when querying single field
// q is {"studentNames": 20}

val q = query[Course](course => course.studentNames.exists(_ > 20) && course.tutor == "Pavlov")
// q is {"studentNames": {"$elemMatch": {"studentNames": {"$gt": 20}, "tutor": "Pavlov"}}}

```

#### Update operators

I Field Update Operators
Expand Down
4 changes: 3 additions & 1 deletion oolong-core/src/main/scala/oolong/AstParser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import scala.annotation.tailrec
import scala.language.postfixOps
import scala.quoted.*

import oolong.AstParser
import oolong.UExpr.FieldUpdateExpr
import oolong.Utils.*
import oolong.dsl.*
Expand Down Expand Up @@ -63,6 +62,9 @@ private[oolong] class DefaultAstParser(using quotes: Quotes) extends AstParser {
case '{ ($x: Seq[_]).length == ($y: Int) } =>
QExpr.Size(parse(x), parse(y))

case '{ type t; ($x: Seq[`t`]).exists($y: (`t` => Boolean)) } => // not text & where
QExpr.ElemMatch(parse(x), parseQExpr(y))

case AsTerm(Apply(Select(lhs, "<="), List(rhs))) =>
QExpr.Lte(parse(lhs.asExpr), parse(rhs.asExpr))

Expand Down
2 changes: 1 addition & 1 deletion oolong-core/src/main/scala/oolong/Backend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,5 @@ private[oolong] trait Backend[Ast, OptimizableRepr, TargetRepr] {
/**
* Perform optimizations that are specific to this backend.
*/
def optimize(query: OptimizableRepr): OptimizableRepr = query
def optimize(query: OptimizableRepr)(using quotes: Quotes): OptimizableRepr = query
}
11 changes: 6 additions & 5 deletions oolong-core/src/main/scala/oolong/LogicalOptimizer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,12 @@ private[oolong] object LogicalOptimizer {
}

ast match {
case QExpr.And(children) => flatten(QExpr.And(children.map(optimize)))
case QExpr.Or(children) => flatten(QExpr.Or(children.map(optimize)))
case QExpr.Not(QExpr.Not(e)) => e
case QExpr.Not(QExpr.Eq(l, r)) => QExpr.Ne(l, r)
case _ => ast
case QExpr.ElemMatch(field, expr) => QExpr.ElemMatch(field, flatten(expr))
case QExpr.And(children) => flatten(QExpr.And(children.map(optimize)))
case QExpr.Or(children) => flatten(QExpr.Or(children.map(optimize)))
case QExpr.Not(QExpr.Not(e)) => e
case QExpr.Not(QExpr.Eq(l, r)) => QExpr.Ne(l, r)
case _ => ast
}
}
}
2 changes: 2 additions & 0 deletions oolong-core/src/main/scala/oolong/QExpr.scala
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,6 @@ private[oolong] object QExpr {
case class TypeCheck[T](x: QExpr, typeInfo: TypeInfo[T]) extends QExpr

case class Mod(x: QExpr, divisor: QExpr, remainder: QExpr) extends QExpr

case class ElemMatch(x: QExpr, y: QExpr) extends QExpr
}
111 changes: 74 additions & 37 deletions oolong-mongo/src/main/scala/oolong/mongo/MongoQueryCompiler.scala
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
package oolong.mongo

import java.util.regex.Pattern
import scala.annotation.tailrec
import scala.jdk.CollectionConverters.*
import scala.quoted.Expr
import scala.quoted.Quotes
import scala.quoted.Type
import scala.util.chaining.*

import oolong.*
import oolong.TypeInfo
import oolong.Utils.PatternInstance.given
import oolong.bson.*
import oolong.bson.meta.QueryMeta
Expand Down Expand Up @@ -96,19 +97,27 @@ object MongoQueryCompiler extends Backend[QExpr, MQ, BsonDocument] {
report.errorAndAbort(s"Expected the subquery inside 'unchecked(...)' to have 'org.mongodb.scala.bson.BsonDocument' type, but the subquery is '${code.show}'")
}
case not: QExpr.Not => handleInnerNot(not)(renames)
case elemMatch: QExpr.ElemMatch =>
val cond = rec(elemMatch.y, renames)
val query = cond match
case f: MQ.OnField => MQ.ElemMatch(f)
case and: MQ.And => MQ.ElemMatch(and)
case or: MQ.Or => MQ.ElemMatch(or)
case _ => report.errorAndAbort(s"Wrong condition: ${cond}")
MQ.OnField(getField(elemMatch.x)(renames), query)
}

rec(ast, meta)

}

def getField(f: QExpr)(renames: Map[String, String])(using quotes: Quotes): MQ.Field =
private def getField(f: QExpr)(renames: Map[String, String])(using quotes: Quotes): MQ.Field =
import quotes.reflect.*
f match
case QExpr.Prop(path) => MQ.Field(renames.getOrElse(path, path))
case _ => report.errorAndAbort("Field is of wrong type")
case expr => report.errorAndAbort(s"Field is of wrong type: ${expr}")

def handleInnerNot(not: QExpr.Not)(renames: Map[String, String])(using quotes: Quotes): MongoQueryNode =
private def handleInnerNot(not: QExpr.Not)(renames: Map[String, String])(using quotes: Quotes): MongoQueryNode =
import quotes.reflect.*
not.x match
case QExpr.Gte(x, y) => MQ.OnField(getField(x)(renames), MQ.Not(MQ.Gte(opt(y))))
Expand All @@ -123,7 +132,7 @@ object MongoQueryCompiler extends Backend[QExpr, MQ, BsonDocument] {
case QExpr.Nin(x, y) => MQ.OnField(getField(x)(renames), MQ.Not(MQ.Nin(handleArrayConds(y))))
case _ => report.errorAndAbort("Wrong operator inside $not")

def handleArrayConds(x: List[QExpr] | QExpr)(using quotes: Quotes): List[MQ] | MQ =
private def handleArrayConds(x: List[QExpr] | QExpr)(using quotes: Quotes): List[MQ] | MQ =
x match
case list: List[QExpr @unchecked] => list map opt
case expr: QExpr => opt(expr)
Expand All @@ -132,15 +141,22 @@ object MongoQueryCompiler extends Backend[QExpr, MQ, BsonDocument] {
import quotes.reflect.*
def rec(node: MongoQueryNode)(using quotes: Quotes): String =
node match
case MQ.OnField(prop, x) => "\"" + prop.path + "\"" + ": " + rec(x)
case MQ.Gte(x) => "{ \"$gte\": " + rec(x) + " }"
case MQ.Lte(x) => "{ \"$lte\": " + rec(x) + " }"
case MQ.Gt(x) => "{ \"$gt\": " + rec(x) + " }"
case MQ.Lt(x) => "{ \"$lt\": " + rec(x) + " }"
case MQ.Eq(x) => rec(x)
case MQ.Ne(x) => "{ \"$ne\": " + rec(x) + " }"
case MQ.Not(x) => "{ \"$not\": " + rec(x) + " }"
case MQ.Size(x) => "{ \"$size\": " + rec(x) + " }"
case MQ.OnField(prop, x) if prop.path.nonEmpty => "\"" + prop.path + "\"" + ": " + rec(x)
case MQ.OnField(_, x) => // when querying array of primitives
val res = x match
case MQ.Eq(x) => "{ \"$eq\": " + rec(x) + " }"
// fixes querying array of primitives with `==` and smth else (wrong query, but has to compile anyway)
case other => rec(other)
if res.startsWith("{ ") then res.drop(2).dropRight(2)
else res
case MQ.Gte(x) => "{ \"$gte\": " + rec(x) + " }"
case MQ.Lte(x) => "{ \"$lte\": " + rec(x) + " }"
case MQ.Gt(x) => "{ \"$gt\": " + rec(x) + " }"
case MQ.Lt(x) => "{ \"$lt\": " + rec(x) + " }"
case MQ.Eq(x) => rec(x)
case MQ.Ne(x) => "{ \"$ne\": " + rec(x) + " }"
case MQ.Not(x) => "{ \"$not\": " + rec(x) + " }"
case MQ.Size(x) => "{ \"$size\": " + rec(x) + " }"
case MQ.Regex(pattern) =>
pattern.value match
case Some(p: Pattern) =>
Expand All @@ -151,19 +167,21 @@ object MongoQueryCompiler extends Backend[QExpr, MQ, BsonDocument] {
case MQ.In(exprs) => "{ \"$in\": [" + renderArrays(exprs) + "] }"
case MQ.Nin(exprs) => "{ \"$nin\": [" + renderArrays(exprs) + "] }"
case MQ.And(exprs) =>
val fields = exprs.collect { case q: MQ.OnField => q.field.path.mkString(".") }
if (fields.distinct.size < fields.size)
val fields = exprs.collect { case q: MQ.OnField if q.field.path.nonEmpty => q.field.path.mkString(".") }
if ((fields.distinct.size < fields.size) && fields.nonEmpty)
"\"$and\": [ " + exprs.map(rec).map("{ " + _ + " }").mkString(", ") + " ]"
else exprs.map(rec).mkString(", ")
case MQ.Or(exprs) => "\"$or\": [ " + exprs.map(rec).map("{ " + _ + " }").mkString(", ") + " ]"
case MQ.Exists(x) => " { \"$exists\": " + rec(x) + " }"
case MQ.Exists(x) => "{ \"$exists\": " + rec(x) + " }"
case MQ.Constant(s: String) => "\"" + s + "\""
case MQ.Constant(s: Any) => s.toString // also limit
case MQ.ScalaCode(code) => renderCode(code)
case MQ.ScalaCodeIterable(_) => "?"
case MQ.Subquery(_) => "{...}"
case MQ.TypeCheck(bsonType) => "{ \"$type\": " + rec(bsonType) + " }"
case MQ.Mod(divisor, remainder) => "{ \"$mod\": [" + rec(divisor) + "," + rec(remainder) + "] }"
case MQ.ElemMatch(expr) =>
"{ \"$elemMatch\": { " + rec(expr) + " } }"
case MQ.Field(field) =>
report.errorAndAbort(s"There is no filter condition on field ${field.mkString(".")}")
end rec
Expand All @@ -182,17 +200,24 @@ object MongoQueryCompiler extends Backend[QExpr, MQ, BsonDocument] {
override def target(optRepr: MongoQueryNode)(using quotes: Quotes): Expr[BsonDocument] =
import quotes.reflect.*
optRepr match {
case and: MQ.And => handleAnd(and)
case or: MQ.Or => handleOr(or)
case MQ.OnField(prop, x) => '{ BsonDocument(${ Expr(prop.path) } -> ${ parseOperators(x) }) }
case MQ.Subquery(doc) => doc
case _ => report.errorAndAbort("given node can't be in that position")
case and: MQ.And => handleAnd(and)
case or: MQ.Or => handleOr(or)
case MQ.OnField(prop, x) if prop.path.nonEmpty =>
'{ BsonDocument(${ Expr(prop.path) } -> ${ parseOperators(x) }) }
case MQ.OnField(_, x) => parseOperatorsAsBsonDocument(x)
case MQ.Subquery(doc) => doc
case _ => report.errorAndAbort(s"given node can't be in that position ${optRepr}")
}

def parseOperators(optRepr: MongoQueryNode)(using quotes: Quotes): Expr[BsonValue] =
private def parseOperators(optRepr: MongoQueryNode)(using quotes: Quotes): Expr[BsonValue] =
parseEq.lift(optRepr).getOrElse(parseOperatorsAsBsonDocument(optRepr))
private def parseEq(using quotes: Quotes): PartialFunction[MQ, Expr[BsonValue]] =
case MQ.Eq(x) => handleValues(x)
private def parseOperatorsAsBsonDocument(optRepr: MongoQueryNode)(using quotes: Quotes): Expr[BsonDocument] =
import quotes.reflect.*
optRepr match
case MQ.Eq(x) => handleValues(x)
case MQ.Eq(x) =>
'{ BsonDocument("$eq" -> ${ handleValues(x) }) }
case MQ.Gte(x) =>
'{ BsonDocument("$gte" -> ${ handleValues(x) }) }
case MQ.Lte(x) =>
Expand Down Expand Up @@ -238,10 +263,15 @@ object MongoQueryCompiler extends Backend[QExpr, MQ, BsonDocument] {
'{
BsonDocument("$mod" -> BsonArray.fromIterable(List(${ handleValues(divisor) }, ${ handleValues(remainder) })))
}
case MQ.ElemMatch(expr) =>
'{
BsonDocument(
"$elemMatch" -> ${ target(expr) }
)
}
case _ => report.errorAndAbort(s"Wrong operator: ${optRepr}")
end parseOperators

def handleArrayCond(x: List[MQ] | MQ)(using q: Quotes): Expr[BsonValue] =
private def handleArrayCond(x: List[MQ] | MQ)(using q: Quotes): Expr[BsonValue] =
import q.reflect.*
x match
case list: List[MQ @unchecked] =>
Expand All @@ -259,22 +289,22 @@ object MongoQueryCompiler extends Backend[QExpr, MQ, BsonDocument] {
}
case _ => report.errorAndAbort("Incorrect condition for array")

def handleAnd(and: MQ.And)(using q: Quotes): Expr[BsonDocument] =
private def handleAnd(and: MQ.And)(using q: Quotes): Expr[BsonDocument] =
'{
val exprs: List[BsonDocument] = ${ Expr.ofList(and.exprs.map(target)) }
if (exprs.flatMap(_.keySet().asScala).distinct.size < exprs.size)
BsonDocument("$and" -> BsonArray.fromIterable(exprs))
else BsonDocument(exprs.map(_.asScala.toList).foldLeft(List.empty[(String, BsonValue)])(_ ++ _))
}

def handleOr(or: MQ.Or)(using q: Quotes): Expr[BsonDocument] =
private def handleOr(or: MQ.Or)(using q: Quotes): Expr[BsonDocument] =
'{
BsonDocument("$or" -> BsonArray.fromIterable(${
Expr.ofList(or.exprs.map(target))
}))
}

def handleValues(expr: MongoQueryNode)(using q: Quotes): Expr[BsonValue] =
private def handleValues(expr: MongoQueryNode)(using q: Quotes): Expr[BsonValue] =
import q.reflect.*
expr match {
case MQ.Constant(i: Long) =>
Expand All @@ -299,13 +329,7 @@ object MongoQueryCompiler extends Backend[QExpr, MQ, BsonDocument] {
case _ => report.errorAndAbort(s"Given type is not literal constant")
}

def extractField(expr: MongoQueryNode)(using q: Quotes): Expr[String] =
import q.reflect.*
expr match
case MQ.Field(path) => Expr(path)
case _ => report.errorAndAbort("field should be string")

def parsePattern(pattern: Pattern): (String, Option[String]) =
private def parsePattern(pattern: Pattern): (String, Option[String]) =
val flags = List(
if (pattern.flags & Pattern.CASE_INSENSITIVE) != 0 then Some("i") else None,
if (pattern.flags & Pattern.MULTILINE) != 0 then Some("m") else None,
Expand All @@ -317,6 +341,19 @@ object MongoQueryCompiler extends Backend[QExpr, MQ, BsonDocument] {
val matcher = Pattern.compile("(\\(\\?([a-z]*)\\))?(.*)").matcher(pattern.pattern)
matcher.matches()
matcher.group(3) -> options
end parsePattern

override def optimize(query: MQ)(using quotes: Quotes): MQ =
val opt: PartialFunction[MQ, MQ] = { case q @ MQ.OnField(_, _: MQ.ElemMatch) =>
optimizeElemMatch(q)

}
opt.lift(query).getOrElse(query)

private def optimizeElemMatch(elemMatch: MQ.OnField)(using quotes: Quotes): MQ =
import quotes.reflect.*
elemMatch match
case q @ MQ.OnField(_, MQ.ElemMatch(_: MQ.And | _: MQ.Or)) => q
case MQ.OnField(first, MQ.ElemMatch(MQ.OnField(second, expr))) =>
optimize(MQ.OnField(MQ.Field(Vector(first.path, second.path).filter(_.nonEmpty).mkString(".")), expr)) // flatten $elemMatch for querying one field
case _ => report.errorAndAbort(s"Not a ElemMatch: ${elemMatch}")
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,6 @@ case object MongoQueryNode {
case class TypeCheck(bsonType: Constant[Int]) extends MQ

case class Mod(divisor: MQ, remainder: MQ) extends MQ

case class ElemMatch(x: MQ) extends MQ
}
Loading
Loading