Skip to content

Commit

Permalink
[query] Lowering + Optimisation with implicit timing context
Browse files Browse the repository at this point in the history
  • Loading branch information
ehigham committed Jan 29, 2025
1 parent d14286b commit 5a64ac4
Show file tree
Hide file tree
Showing 27 changed files with 765 additions and 778 deletions.
2 changes: 1 addition & 1 deletion hail/hail/src/is/hail/backend/local/LocalBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ object LocalBackend extends Backend {
Validate(ir)
val queryID = Backend.nextID()
log.info(s"starting execution of query $queryID of initial size ${IRSize(ir)}")
ctx.irMetadata.semhash = SemanticHash(ctx)(ir)
ctx.irMetadata.semhash = SemanticHash(ctx, ir)
val res = _jvmLowerAndExecute(ctx, ir)
log.info(s"finished execution of query $queryID")
res
Expand Down
2 changes: 1 addition & 1 deletion hail/hail/src/is/hail/backend/service/ServiceBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ class ServiceBackend(
Validate(ir)
val queryID = Backend.nextID()
log.info(s"starting execution of query $queryID of initial size ${IRSize(ir)}")
ctx.irMetadata.semhash = SemanticHash(ctx)(ir)
ctx.irMetadata.semhash = SemanticHash(ctx, ir)
val res = _jvmLowerAndExecute(ctx, ir)
log.info(s"finished execution of query $queryID")
res
Expand Down
2 changes: 1 addition & 1 deletion hail/hail/src/is/hail/backend/spark/SparkBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ class SparkBackend(val sc: SparkContext) extends Backend {
ctx.time {
TypeCheck(ctx, ir)
Validate(ir)
ctx.irMetadata.semhash = SemanticHash(ctx)(ir)
ctx.irMetadata.semhash = SemanticHash(ctx, ir)
try {
val lowerTable = ctx.flags.get("lower") != null
val lowerBM = ctx.flags.get("lower_bm") != null
Expand Down
10 changes: 5 additions & 5 deletions hail/hail/src/is/hail/expr/ir/BaseIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ abstract class BaseIR {
// New sentinel values can be obtained by `nextFlag` on `IRMetadata`.
var mark: Int = 0

def isAlphaEquiv(ctx: ExecuteContext, other: BaseIR): Boolean =
/* FIXME: rewrite to not rebuild the irs, by maintaining an env mapping left names to right
* names */
NormalizeNames(ctx, this, allowFreeVariables = true) ==
NormalizeNames(ctx, other, allowFreeVariables = true)
def isAlphaEquiv(ctx: ExecuteContext, other: BaseIR): Boolean = {
// FIXME: rewrite to not rebuild the irs by maintaining an env mapping left to right names
val normalize: (ExecuteContext, BaseIR) => BaseIR = NormalizeNames(allowFreeVariables = true)
normalize(ctx, this) == normalize(ctx, other)
}

def mapChildrenWithIndex(f: (BaseIR, Int) => BaseIR): BaseIR = {
val newChildren = childrenSeq.view.zipWithIndex.map(f.tupled).toArray
Expand Down
2 changes: 0 additions & 2 deletions hail/hail/src/is/hail/expr/ir/Compilable.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ object InterpretableButNotCompilable {
case _: MatrixToValueApply => true
case _: BlockMatrixToValueApply => true
case _: BlockMatrixCollect => true
case _: BlockMatrixToTableApply => true
case _ => false
}
}
Expand All @@ -44,7 +43,6 @@ object Compilable {
case _: TableToValueApply => false
case _: MatrixToValueApply => false
case _: BlockMatrixToValueApply => false
case _: BlockMatrixToTableApply => false
case _: RelationalRef => false
case _: RelationalLet => false
case _ => true
Expand Down
2 changes: 1 addition & 1 deletion hail/hail/src/is/hail/expr/ir/Compile.scala
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ object compile {
N: sourcecode.Name,
): (Option[SingleCodeType], (HailClassLoader, FS, HailTaskContext, Region) => F with Mixin) =
ctx.time {
val normalizedBody = NormalizeNames(ctx, body, allowFreeVariables = true)
val normalizedBody = NormalizeNames(allowFreeVariables = true)(ctx, body)
ctx.CodeCache.getOrElseUpdate(
CodeCacheKey(aggSigs.getOrElse(Array.empty).toFastSeq, params, normalizedBody), {
var ir = Subst(
Expand Down
73 changes: 37 additions & 36 deletions hail/hail/src/is/hail/expr/ir/ExtractIntervalFilters.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,52 +25,53 @@ object ExtractIntervalFilters {

val MAX_LITERAL_SIZE = 4096

def apply(ctx: ExecuteContext, ir0: BaseIR): BaseIR = {
MapIR.mapBaseIR(
ir0,
(ir: BaseIR) => {
(
ir match {
case TableFilter(child, pred) =>
extractPartitionFilters(
ctx,
pred,
Ref(TableIR.rowName, child.typ.rowType),
child.typ.key,
)
.map { case (newCond, intervals) =>
def apply(ctx: ExecuteContext, ir0: BaseIR): BaseIR =
ctx.time {
MapIR.mapBaseIR(
ir0,
ir =>
(
ir match {
case TableFilter(child, pred) =>
extractPartitionFilters(
ctx,
pred,
Ref(TableIR.rowName, child.typ.rowType),
child.typ.key,
)
.map { case (newCond, intervals) =>
log.info(
s"generated TableFilterIntervals node with ${intervals.length} intervals:\n " +
s"Intervals: ${intervals.mkString(", ")}\n " +
s"Predicate: ${Pretty(ctx, pred)}\n " + s"Post: ${Pretty(ctx, newCond)}"
)
TableFilter(TableFilterIntervals(child, intervals, keep = true), newCond)
}
case MatrixFilterRows(child, pred) =>
extractPartitionFilters(
ctx,
pred,
Ref(MatrixIR.rowName, child.typ.rowType),
child.typ.rowKey,
).map { case (newCond, intervals) =>
log.info(
s"generated TableFilterIntervals node with ${intervals.length} intervals:\n " +
s"generated MatrixFilterIntervals node with ${intervals.length} intervals:\n " +
s"Intervals: ${intervals.mkString(", ")}\n " +
s"Predicate: ${Pretty(ctx, pred)}\n " + s"Post: ${Pretty(ctx, newCond)}"
)
TableFilter(TableFilterIntervals(child, intervals, keep = true), newCond)
MatrixFilterRows(MatrixFilterIntervals(child, intervals, keep = true), newCond)
}
case MatrixFilterRows(child, pred) => extractPartitionFilters(
ctx,
pred,
Ref(MatrixIR.rowName, child.typ.rowType),
child.typ.rowKey,
).map { case (newCond, intervals) =>
log.info(
s"generated MatrixFilterIntervals node with ${intervals.length} intervals:\n " +
s"Intervals: ${intervals.mkString(", ")}\n " +
s"Predicate: ${Pretty(ctx, pred)}\n " + s"Post: ${Pretty(ctx, newCond)}"
)
MatrixFilterRows(MatrixFilterIntervals(child, intervals, keep = true), newCond)
}

case _ => None
}
).getOrElse(ir)
},
)
}
case _ => None
}
).getOrElse(ir),
)
}

def extractPartitionFilters(ctx: ExecuteContext, cond: IR, ref: Ref, key: IndexedSeq[String])
: Option[(IR, IndexedSeq[Interval])] = {
if (key.isEmpty) None
else {
else ctx.time {
val extract =
new ExtractIntervalFilters(ctx, ref.typ.asInstanceOf[TStruct].typeAfterSelectNames(key))
val trueSet = extract.analyze(cond, ref.name)
Expand Down
4 changes: 3 additions & 1 deletion hail/hail/src/is/hail/expr/ir/FoldConstants.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ import is.hail.utils.HailException

object FoldConstants {
def apply(ctx: ExecuteContext, ir: BaseIR): BaseIR =
ctx.r.pool.scopedRegion(region => ctx.local(r = region)(foldConstants(_, ir)))
ctx.time {
ctx.r.pool.scopedRegion(r => ctx.local(r = r)(foldConstants(_, ir)))
}

private def foldConstants(ctx: ExecuteContext, ir: BaseIR): BaseIR =
RewriteBottomUp(
Expand Down
95 changes: 53 additions & 42 deletions hail/hail/src/is/hail/expr/ir/ForwardLets.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,20 @@ package is.hail.expr.ir
import is.hail.backend.ExecuteContext
import is.hail.expr.ir.defs.{BaseRef, Binding, Block, In, Ref, Str}
import is.hail.types.virtual.TVoid
import is.hail.utils.BoxedArrayBuilder
import is.hail.utils.{fatal, BoxedArrayBuilder}

import scala.collection.Set
import scala.util.control.NonFatal

object ForwardLets {
def apply[T <: BaseIR](ctx: ExecuteContext)(ir0: T): T = {
val ir1 = NormalizeNames(ctx, ir0, allowFreeVariables = true)
val UsesAndDefs(uses, defs, _) = ComputeUsesAndDefs(ir1, errorIfFreeVariables = false)
val nestingDepth = NestingDepth(ir1)

def rewrite(ir: BaseIR, env: BindingEnv[IR]): BaseIR = {
def apply[T <: BaseIR](ctx: ExecuteContext, ir0: T): T =
ctx.time {
val ir1 = NormalizeNames(allowFreeVariables = true)(ctx, ir0)
val UsesAndDefs(uses, defs, _) = ComputeUsesAndDefs(ir1, errorIfFreeVariables = false)
val nestingDepth = NestingDepth(ctx, ir1)

def shouldForward(value: IR, refs: Set[RefEquality[BaseRef]], base: Block, scope: Int)
: Boolean = {
: Boolean =
IsPure(value) && (
value.isInstanceOf[Ref] ||
value.isInstanceOf[In] ||
Expand All @@ -28,45 +28,56 @@ object ForwardLets {
!ContainsAgg(value)) &&
!ContainsAggIntermediate(value)
)
}

ir match {
case l: Block =>
val keep = new BoxedArrayBuilder[Binding]
val refs = uses(l)
val newEnv = l.bindings.foldLeft(env) {
case (env, Binding(name, value, scope)) =>
val rewriteValue = rewrite(value, env.promoteScope(scope)).asInstanceOf[IR]
if (
rewriteValue.typ != TVoid
&& shouldForward(rewriteValue, refs.filter(_.t.name == name), l, scope)
) {
env.bindInScope(name, rewriteValue, scope)
} else {
keep += Binding(name, rewriteValue, scope)
env
def rewrite(ir: BaseIR, env: BindingEnv[IR]): BaseIR =
ir match {
case l: Block =>
val keep = new BoxedArrayBuilder[Binding]
val refs = uses(l)
val newEnv = l.bindings.foldLeft(env) {
case (env, Binding(name, value, scope)) =>
val rewriteValue = rewrite(value, env.promoteScope(scope)).asInstanceOf[IR]
if (
rewriteValue.typ != TVoid
&& shouldForward(rewriteValue, refs.filter(_.t.name == name), l, scope)
) {
env.bindInScope(name, rewriteValue, scope)
} else {
keep += Binding(name, rewriteValue, scope)
env
}
}

val newBody = rewrite(l.body, newEnv).asInstanceOf[IR]
if (keep.isEmpty) newBody
else Block(keep.result(), newBody)

case x @ Ref(name, _) =>
env.eval
.lookupOption(name)
.map { forwarded =>
if (uses.lookup(defs.lookup(x)).count(_.t.name == name) > 1) forwarded.deepCopy()
else forwarded
}
}
.getOrElse(x)
case _ =>
ir.mapChildrenWithIndex((ir1, i) =>
rewrite(ir1, env.extend(Bindings.get(ir, i).dropBindings))
)
}

val newBody = rewrite(l.body, newEnv).asInstanceOf[IR]
if (keep.isEmpty) newBody
else Block(keep.result(), newBody)
val ir = rewrite(ir1, BindingEnv(Env.empty, Some(Env.empty), Some(Env.empty)))

case x @ Ref(name, _) =>
env.eval
.lookupOption(name)
.map { forwarded =>
if (uses.lookup(defs.lookup(x)).count(_.t.name == name) > 1) forwarded.deepCopy()
else forwarded
}
.getOrElse(x)
case _ =>
ir.mapChildrenWithIndex((ir1, i) =>
rewrite(ir1, env.extend(Bindings.get(ir, i).dropBindings))
try
TypeCheck(ctx, ir)
catch {
case NonFatal(e) =>
fatal(
s"bad ir from ForwardLets, started as\n${Pretty(ctx, ir0, preserveNames = true)}",
e,
)
}
}

rewrite(ir1, BindingEnv(Env.empty, Some(Env.empty), Some(Env.empty))).asInstanceOf[T]
}
ir.asInstanceOf[T]
}
}
Loading

0 comments on commit 5a64ac4

Please sign in to comment.