Skip to content

Commit

Permalink
Fix unionAll bug opencypher#402
Browse files Browse the repository at this point in the history
Co-authored-by: Tobias Johansson <[email protected]>
  • Loading branch information
soerenreichardt and tobias-johansson committed Oct 30, 2018
1 parent 827d254 commit 7b438e2
Show file tree
Hide file tree
Showing 7 changed files with 32 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@ trait RelationalCypherGraphFactory[T <: Table[T]] {
(implicit context: RelationalRuntimeContext[T]): Graph = new SingleTableGraph(drivingTable, schema, tagsUsed)

def unionGraph(graphs: RelationalCypherGraph[T]*)(implicit context: RelationalRuntimeContext[T]): Graph = {
unionGraph(computeRetaggings(graphs.map(g => g -> g.tags).toMap))
unionGraph(computeRetaggings(graphs.map(g => g -> g.tags)).toList)
}

def unionGraph(graphsToReplacements: Map[RelationalCypherGraph[T], Map[Int, Int]])
def unionGraph(graphsToReplacements: List[(RelationalCypherGraph[T], Map[Int, Int])])
(implicit context: RelationalRuntimeContext[T]): Graph = UnionGraph(graphsToReplacements)

def empty: Graph = EmptyGraph()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,21 +36,21 @@ object TagSupport {
* graphs the required retaggings are computed on top of the fixed retaggings.
*/
def computeRetaggings[GraphKey](
graphs: Map[GraphKey, Set[Int]],
fixedRetaggings: Map[GraphKey, Map[Int, Int]] = Map.empty[GraphKey, Map[Int, Int]]
): Map[GraphKey, Map[Int, Int]] = {
graphs: Seq[(GraphKey, Set[Int])],
fixedRetaggings: Seq[(GraphKey, Map[Int, Int])] = Seq.empty
): Seq[(GraphKey, Map[Int, Int])] = {
val graphsToRetag = graphs.filterNot { case (qgn, _) => fixedRetaggings.contains(qgn) }
val usedTags = fixedRetaggings.values.flatMap(_.values).toSet
val usedTags = fixedRetaggings.map(_._2).flatMap(_.values).toSet
val (result, _) = graphsToRetag.foldLeft((fixedRetaggings, usedTags)) {
case ((graphReplacements, previousTags), (graphId, rightTags)) =>

val replacements = previousTags.replacementsFor(rightTags)
val updatedRightTags = rightTags.replaceWith(replacements)

val updatedPreviousTags = previousTags ++ updatedRightTags
val updatedGraphReplacements = graphReplacements.updated(graphId, replacements)
val updatedGraphReplacements = graphReplacements :+ ((graphId, replacements))

updatedGraphReplacements -> updatedPreviousTags
(updatedGraphReplacements, updatedPreviousTags)
}
result
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,11 @@ import org.opencypher.okapi.relational.impl.planning.RelationalPlanner._
import scala.reflect.runtime.universe.TypeTag

// TODO: This should be a planned tree of physical operators instead of a graph
final case class UnionGraph[T <: Table[T] : TypeTag](graphsToReplacements: Map[RelationalCypherGraph[T], Map[Int, Int]])
final case class UnionGraph[T <: Table[T] : TypeTag](graphsToReplacements: Seq[(RelationalCypherGraph[T], Map[Int, Int])])
(implicit context: RelationalRuntimeContext[T]) extends RelationalCypherGraph[T] {

private val (graphs, replacements) = graphsToReplacements.unzip

override implicit val session: RelationalCypherSession[T] = context.session

override type Records = RelationalCypherRecords[T]
Expand All @@ -50,12 +52,12 @@ final case class UnionGraph[T <: Table[T] : TypeTag](graphsToReplacements: Map[R

require(graphsToReplacements.nonEmpty, "Union requires at least one graph")

override def tables: Seq[T] = graphsToReplacements.keys.flatMap(_.tables).toSeq
override def tables: Seq[T] = graphs.flatMap(_.tables)

override lazy val tags: Set[Int] = graphsToReplacements.values.flatMap(_.values).toSet
override lazy val tags: Set[Int] = replacements.flatMap(_.values).toSet

override lazy val schema: Schema = {
graphsToReplacements.keys.map(g => g.schema).foldLeft(Schema.empty)(_ ++ _)
graphs.map(g => g.schema).foldLeft(Schema.empty)(_ ++ _)
}

override def toString = s"UnionGraph(graphs=[${graphsToReplacements.mkString(",")}])"
Expand All @@ -66,11 +68,12 @@ final case class UnionGraph[T <: Table[T] : TypeTag](graphsToReplacements: Map[R
): RelationalOperator[T] = {
val targetEntity = Var("")(entityType)
val targetEntityHeader = schema.headerForEntity(targetEntity, exactLabelMatch)
val alignedScans = graphsToReplacements.keys
.map { graph =>
val scanOp = graph.scanOperator(entityType, exactLabelMatch)
val retagOp = scanOp.retagVariable(targetEntity, graphsToReplacements(graph))
retagOp.alignWith(targetEntity, targetEntityHeader)
val alignedScans = graphsToReplacements
.map {
case (graph, replacement) =>
val scanOp = graph.scanOperator(entityType, exactLabelMatch)
val retagOp = scanOp.retagVariable(targetEntity, replacement)
retagOp.alignWith(targetEntity, targetEntityHeader)
}
// TODO: find out if a graph returns empty records and skip union operation
Distinct(alignedScans.reduce(TabularUnionAll(_, _)), Set(targetEntity))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -524,12 +524,12 @@ final case class GraphUnionAll[T <: Table[T] : TypeTag](

import org.opencypher.okapi.relational.api.tagging.TagSupport._

override lazy val tagStrategy: Map[QualifiedGraphName, Map[Int, Int]] = computeRetaggings(inputs.toList.map(r => r.graphName -> r.graph.tags).toMap)
override lazy val tagStrategy: Map[QualifiedGraphName, Map[Int, Int]] = computeRetaggings(inputs.toList.map(r => r.graphName -> r.graph.tags)).toMap

override lazy val graphName: QualifiedGraphName = qgn

override lazy val graph: RelationalCypherGraph[T] = {
val graphWithTagStrategy = inputs.toList.map(i => i.graph -> tagStrategy(i.graphName)).toMap
val graphWithTagStrategy = inputs.toList.map(i => i.graph -> tagStrategy(i.graphName))
session.graphs.unionGraph(graphWithTagStrategy)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ object ConstructGraphPlanner {
construct.onGraphs match {
case Nil => relational.Start[T](context.session.emptyGraphQgn) // Empty start
case one :: Nil => // Just one graph, no union required
relational.Start(one, tagStrategy = computeRetaggings(Map(one -> context.resolveGraph(one).tags)))
relational.Start(one, tagStrategy = computeRetaggings(List(one -> context.resolveGraph(one).tags)).toMap)
case several =>
val onGraphPlans = NonEmptyList.fromListUnsafe(several).map(qgn => relational.Start[T](qgn))
relational.GraphUnionAll[T](onGraphPlans, construct.qualifiedGraphName)
Expand All @@ -73,7 +73,7 @@ object ConstructGraphPlanner {
val allGraphs = unionTagStrategy.keySet ++ matchGraphs
val tagsForGraph: Map[QualifiedGraphName, Set[Int]] = allGraphs.map(qgn => qgn -> context.resolveGraph(qgn).tags).toMap

val constructTagStrategy = computeRetaggings(tagsForGraph, unionTagStrategy)
val constructTagStrategy = computeRetaggings(tagsForGraph.toSeq, unionTagStrategy.toSeq).toMap

// Apply aliases in CLONE to input table in order to create the base table, on which CONSTRUCT happens
val aliasClones = clonedVarsToInputVars
Expand Down Expand Up @@ -131,7 +131,7 @@ object ConstructGraphPlanner {
val graph = if (onGraph == context.session.graphs.empty) {
context.session.graphs.unionGraph(patternGraph)
} else {
context.session.graphs.unionGraph(Map(identityRetaggings(onGraph), identityRetaggings(patternGraph)))
context.session.graphs.unionGraph(List(identityRetaggings(onGraph), identityRetaggings(patternGraph)))
}

val constructOp = ConstructGraph(inputTablePlan, graph, name, constructTagStrategy, construct, context)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,8 @@
package org.opencypher.spark.impl

import org.apache.spark.sql.Row
import org.opencypher.okapi.relational.api.graph.RelationalCypherGraph
import org.opencypher.okapi.relational.api.tagging.Tags._
import org.opencypher.okapi.testing.Bag
import org.opencypher.spark.impl.table.SparkTable.DataFrameTable
import org.opencypher.spark.testing.fixture.{GraphConstructionFixture, RecordsVerificationFixture, TeamDataFixture}

class UnionGraphTest extends CAPSGraphTest
Expand All @@ -48,6 +46,12 @@ class UnionGraphTest extends CAPSGraphTest
testGraph1.unionAll(testGraph2).cypher("""MATCH (n) RETURN DISTINCT id(n)""").records.size should equal(2)
}

it("supports UNION ALL on identical graphs") {
val g = initGraph("CREATE ()")
val union = g.unionAll(g)
union.nodes("n").size shouldBe 2
}

test("Node scan from single node CAPSRecords") {
val inputGraph = initGraph(`:Person`)
val inputNodes = inputGraph.nodes("n")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -646,7 +646,7 @@ class MultipleGraphBehaviour extends CAPSTestSuite with ScanGraphInit {
case Some(relPlan) =>
val switchOp = relPlan.collectFirst { case op: SwitchContext[_] => op }.get
val containsUnionGraph = switchOp.context.queryLocalCatalog.head._2 match {
case g: UnionGraph[_] => g.graphsToReplacements.keys.collectFirst { case op: UnionGraph[_] => op }.isDefined
case g: UnionGraph[_] => g.graphsToReplacements.unzip._1.collectFirst { case op: UnionGraph[_] => op }.isDefined
case _ => false
}
withClue("CONSTRUCT plans union on a single input graph") {
Expand Down

0 comments on commit 7b438e2

Please sign in to comment.