[compiler] Minor Requiredness Performance Enhancements (#13991)
Main change: add `var mark: Int` to `BaseIR`.
On profiling the benchmark `matrix_multi_write_nothing`, I noticed a
significant amount of time was spent
- iterating through zipped arrays in requiredness
- adding and removing elements from `HashSet`s.
In fact, half the time spent in requiredness was removing IR nodes from
the `HashSet` used as the queue! With this change, requiredness runs
like a stabbed rat!

Explanation of `mark`:
This field acts as a flag that analyses can set. For example:
- `HasIRSharing` uses the field to see if it has visited a node before.
- `Requiredness` uses the field to tell if a node is currently enqueued.

The `nextFlag` method on `IrMetadata` lets an analysis obtain a fresh
value to set the `mark` field to.
This removes the need to traverse the IR after each analysis to re-zero
every `mark` field.
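
To make the usage pattern concrete, here is a rough sketch. The analysis
itself (`CountUniqueNodes`) is made up purely for illustration and is not part
of this change; `ctx.irMetadata.nextFlag`, `mark`, and `IRTraversal.levelOrder`
are the real pieces, and the import paths are assumed from the surrounding
files:

```scala
import is.hail.backend.ExecuteContext
import is.hail.expr.ir.{BaseIR, IRTraversal}

// Hypothetical analysis: counts distinct nodes in an IR that may contain
// shared subtrees, without allocating a HashSet of visited nodes.
object CountUniqueNodes {
  def apply(ctx: ExecuteContext)(ir: BaseIR): Int = {
    // A fresh sentinel value; no other pass ever uses it, so stale marks
    // left behind by earlier passes read as "unvisited".
    val mark = ctx.irMetadata.nextFlag
    var count = 0
    for (node <- IRTraversal.levelOrder(ir)) {
      if (node.mark != mark) { // first time this pass sees the node
        node.mark = mark
        count += 1
      }
    }
    // No cleanup traversal is needed afterwards: the next analysis simply
    // asks `nextFlag` for a new sentinel.
    count
  }
}
```

The same trick gives the new `Requiredness` queue an O(1) "is this node
already enqueued?" test, replacing the old `HashSet` membership check.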
ehigham authored Nov 16, 2023
1 parent 001f93a commit eaf4197
Showing 18 changed files with 110 additions and 71 deletions.
12 changes: 11 additions & 1 deletion hail/src/main/scala/is/hail/expr/ir/BaseIR.scala
@@ -1,5 +1,6 @@
package is.hail.expr.ir

import is.hail.backend.ExecuteContext
import is.hail.types.BaseType
import is.hail.types.virtual.Type
import is.hail.utils.StackSafe._
@@ -16,7 +17,16 @@ abstract class BaseIR {

def deepCopy(): this.type = copy(newChildren = childrenSeq.map(_.deepCopy())).asInstanceOf[this.type]

lazy val noSharing: this.type = if (HasIRSharing(this)) this.deepCopy() else this
def noSharing(ctx: ExecuteContext): this.type =
if (HasIRSharing(ctx)(this)) this.deepCopy() else this


// For use as a boolean flag by IR passes. Each pass uses a different sentinel value to encode
// "true" (and anything else is false). As long as we maintain the global invariant that no
// two passes use the same sentinel value, this allows us to reuse this field across passes
// without ever having to initialize it at the start of a pass.
// New sentinel values can be obtained by `nextFlag` on `IRMetadata`.
var mark: Int = 0

def mapChildrenWithIndex(f: (BaseIR, Int) => BaseIR): BaseIR = {
val newChildren = childrenSeq.view.zipWithIndex.map(f.tupled).toArray
18 changes: 9 additions & 9 deletions hail/src/main/scala/is/hail/expr/ir/Compile.scala
@@ -32,17 +32,17 @@ object Compile {
print: Option[PrintWriter] = None
): (Option[SingleCodeType], (HailClassLoader, FS, HailTaskContext, Region) => F) = {

val normalizeNames = new NormalizeNames(_.toString)
val normalizedBody = normalizeNames(body,
Env(params.map { case (n, _) => n -> n }: _*))
val normalizedBody = new NormalizeNames(_.toString)(ctx, body,
Env(params.map { case (n, _) => n -> n }: _*)
)
val k = CodeCacheKey(FastSeq[AggStateSig](), params.map { case (n, pt) => (n, pt) }, normalizedBody)
(ctx.backend.lookupOrCompileCachedFunction[F](k) {

var ir = body
ir = Subst(ir, BindingEnv(params
.zipWithIndex
.foldLeft(Env.empty[IR]) { case (e, ((n, t), i)) => e.bind(n, In(i, t)) }))
ir = LoweringPipeline.compileLowerer(optimize).apply(ctx, ir).asInstanceOf[IR].noSharing
ir = LoweringPipeline.compileLowerer(optimize).apply(ctx, ir).asInstanceOf[IR].noSharing(ctx)

TypeCheck(ctx, ir, BindingEnv.empty)

@@ -85,17 +85,17 @@ object CompileWithAggregators {
body: IR,
optimize: Boolean = true
): (Option[SingleCodeType], (HailClassLoader, FS, HailTaskContext, Region) => (F with FunctionWithAggRegion)) = {
val normalizeNames = new NormalizeNames(_.toString)
val normalizedBody = normalizeNames(body,
Env(params.map { case (n, _) => n -> n }: _*))
val normalizedBody = new NormalizeNames(_.toString)(ctx, body,
Env(params.map { case (n, _) => n -> n }: _*)
)
val k = CodeCacheKey(aggSigs, params.map { case (n, pt) => (n, pt) }, normalizedBody)
(ctx.backend.lookupOrCompileCachedFunction[F with FunctionWithAggRegion](k) {

var ir = body
ir = Subst(ir, BindingEnv(params
.zipWithIndex
.foldLeft(Env.empty[IR]) { case (e, ((n, t), i)) => e.bind(n, In(i, t)) }))
ir = LoweringPipeline.compileLowerer(optimize).apply(ctx, ir).asInstanceOf[IR].noSharing
ir = LoweringPipeline.compileLowerer(optimize).apply(ctx, ir).asInstanceOf[IR].noSharing(ctx)

TypeCheck(ctx, ir, BindingEnv(Env.fromSeq[Type](params.map { case (name, t) => name -> t.virtualType })))

@@ -184,7 +184,7 @@ object CompileIterator {

val outerRegion = outerRegionField

val ir = LoweringPipeline.compileLowerer(true)(ctx, body).asInstanceOf[IR].noSharing
val ir = LoweringPipeline.compileLowerer(true)(ctx, body).asInstanceOf[IR].noSharing(ctx)
TypeCheck(ctx, ir)

var elementAddress: Settable[Long] = null
2 changes: 1 addition & 1 deletion hail/src/main/scala/is/hail/expr/ir/Emit.scala
@@ -30,7 +30,7 @@ object EmitContext {
def analyze(ctx: ExecuteContext, ir: IR, pTypeEnv: Env[PType] = Env.empty): EmitContext = {
ctx.timer.time("EmitContext.analyze") {
val usesAndDefs = ComputeUsesAndDefs(ir, errorIfFreeVariables = false)
val requiredness = Requiredness.apply(ir, usesAndDefs, null, pTypeEnv)
val requiredness = Requiredness(ir, usesAndDefs, ctx, pTypeEnv)
val inLoopCriticalPath = ControlFlowPreventsSplit(ir, ParentPointers(ir), usesAndDefs)
val methodSplits = ComputeMethodSplits(ctx, ir, inLoopCriticalPath)
new EmitContext(ctx, requiredness, usesAndDefs, methodSplits, inLoopCriticalPath, Memo.empty[Unit])
6 changes: 3 additions & 3 deletions hail/src/main/scala/is/hail/expr/ir/ForwardLets.scala
@@ -1,12 +1,12 @@
package is.hail.expr.ir

import is.hail.utils._
import is.hail.backend.ExecuteContext

import scala.collection.mutable

object ForwardLets {
def apply[T <: BaseIR](ir0: T): T = {
val ir1 = new NormalizeNames(_ => genUID(), allowFreeVariables = true).apply(ir0)
def apply[T <: BaseIR](ctx: ExecuteContext)(ir0: T): T = {
val ir1 = new NormalizeNames(_ => genUID(), allowFreeVariables = true)(ctx, ir0)
val UsesAndDefs(uses, defs, _) = ComputeUsesAndDefs(ir1, errorIfFreeVariables = false)
val nestingDepth = NestingDepth(ir1)

2 changes: 1 addition & 1 deletion hail/src/main/scala/is/hail/expr/ir/Interpret.scala
@@ -24,7 +24,7 @@ object Interpret {
apply(tir, ctx, optimize = true)

def apply(tir: TableIR, ctx: ExecuteContext, optimize: Boolean): TableValue = {
val lowered = LoweringPipeline.legacyRelationalLowerer(optimize)(ctx, tir).asInstanceOf[TableIR].noSharing
val lowered = LoweringPipeline.legacyRelationalLowerer(optimize)(ctx, tir).asInstanceOf[TableIR].noSharing(ctx)
lowered.analyzeAndExecute(ctx).asTableValue(ctx)
}

@@ -60,6 +60,6 @@ object LowerOrInterpretNonCompilable {
}
}

rewrite(ir.noSharing, mutable.HashMap.empty)
rewrite(ir.noSharing(ctx), mutable.HashMap.empty)
}
}
10 changes: 7 additions & 3 deletions hail/src/main/scala/is/hail/expr/ir/NormalizeNames.scala
@@ -1,5 +1,6 @@
package is.hail.expr.ir

import is.hail.backend.ExecuteContext
import is.hail.utils.StackSafe._

class NormalizeNames(normFunction: Int => String, allowFreeVariables: Boolean = false) {
@@ -10,11 +11,14 @@ class NormalizeNames(normFunction: Int => String, allowFreeVariables: Boolean =
normFunction(count)
}

def apply(ir: IR, env: Env[String]): IR = apply(ir.noSharing, BindingEnv(env))
def apply(ctx: ExecuteContext, ir: IR, env: Env[String]): IR =
normalizeIR(ir.noSharing(ctx), BindingEnv(env)).run().asInstanceOf[IR]

def apply(ir: IR, env: BindingEnv[String]): IR = normalizeIR(ir.noSharing, env).run().asInstanceOf[IR]
def apply(ctx: ExecuteContext, ir: IR, env: BindingEnv[String]): IR =
normalizeIR(ir.noSharing(ctx), env).run().asInstanceOf[IR]

def apply(ir: BaseIR): BaseIR = normalizeIR(ir.noSharing, BindingEnv(agg=Some(Env.empty), scan=Some(Env.empty))).run()
def apply(ctx: ExecuteContext, ir: BaseIR): BaseIR =
normalizeIR(ir.noSharing(ctx), BindingEnv(agg=Some(Env.empty), scan=Some(Env.empty))).run()

private def normalizeIR(ir: BaseIR, env: BindingEnv[String], context: Array[String] = Array()): StackFrame[BaseIR] = {

4 changes: 2 additions & 2 deletions hail/src/main/scala/is/hail/expr/ir/Optimize.scala
@@ -21,9 +21,9 @@ object Optimize {
last = ir
runOpt(FoldConstants(ctx, _), iter, "FoldConstants")
runOpt(ExtractIntervalFilters(ctx, _), iter, "ExtractIntervalFilters")
runOpt(normalizeNames(_), iter, "NormalizeNames")
runOpt(normalizeNames(ctx, _), iter, "NormalizeNames")
runOpt(Simplify(ctx, _), iter, "Simplify")
runOpt(ForwardLets(_), iter, "ForwardLets")
runOpt(ForwardLets(ctx), iter, "ForwardLets")
runOpt(ForwardRelationalLets(_), iter, "ForwardRelationalLets")
runOpt(PruneDeadFields(ctx, _), iter, "PruneDeadFields")

21 changes: 9 additions & 12 deletions hail/src/main/scala/is/hail/expr/ir/RefEquality.scala
@@ -1,5 +1,7 @@
package is.hail.expr.ir

import is.hail.backend.ExecuteContext

import scala.collection.mutable

object RefEquality {
@@ -61,19 +63,14 @@ class Memo[T] private(val m: mutable.HashMap[RefEquality[BaseIR], T]) {


object HasIRSharing {
def apply(ir: BaseIR): Boolean = {
val m = mutable.HashSet.empty[RefEquality[BaseIR]]

def recur(x: BaseIR): Boolean = {
val re = RefEquality(x)
if (m.contains(re))
true
else {
m.add(re)
x.children.exists(recur)
}
def apply(ctx: ExecuteContext)(ir: BaseIR): Boolean = {
val mark = ctx.irMetadata.nextFlag

for (node <- IRTraversal.levelOrder(ir)) {
if (node.mark == mark) return true
node.mark = mark
}

recur(ir)
false
}
}
30 changes: 26 additions & 4 deletions hail/src/main/scala/is/hail/expr/ir/Requiredness.scala
@@ -35,7 +35,7 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) {
type State = Memo[BaseTypeWithRequiredness]
private val cache = Memo.empty[BaseTypeWithRequiredness]
private val dependents = Memo.empty[mutable.Set[RefEquality[BaseIR]]]
private val q = mutable.Set[RefEquality[BaseIR]]()
private[this] val q = new Queue(ctx.irMetadata.nextFlag)

private val defs = Memo.empty[IndexedSeq[BaseTypeWithRequiredness]]
private val states = Memo.empty[IndexedSeq[TypeWithRequiredness]]
@@ -90,8 +90,7 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) {

def run(): Unit = {
while (q.nonEmpty) {
val node = q.head
q -= node
val node = q.pop()
if (analyze(node.t) && dependents.contains(node)) {
q ++= dependents.lookup(node)
}
@@ -615,7 +614,7 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) {
val eltType = tcoerce[RIterable](requiredness).elementType
eltType.unionFrom(lookup(joinF))
case StreamMultiMerge(as, _) =>
requiredness.union(as.forall(lookup(_).required))
requiredness.union(as.forall(lookup(_).required))
val elt = tcoerce[RStruct](tcoerce[RIterable](requiredness).elementType)
as.foreach { a =>
elt.unionFields(tcoerce[RStruct](tcoerce[RIterable](lookup(a)).elementType))
@@ -828,4 +827,27 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) {

requiredness.probeChangedAndReset()
}


final class Queue(val markFlag: Int) {
private[this] val q = mutable.Queue[RefEquality[BaseIR]]()

def nonEmpty: Boolean =
q.nonEmpty

def pop(): RefEquality[BaseIR] = {
val n = q.dequeue()
n.t.mark = 0
n
}

def +=(re: RefEquality[BaseIR]): Unit =
if (re.t.mark != markFlag) {
re.t.mark = markFlag
q += re
}

def ++=(res: Iterable[RefEquality[BaseIR]]): Unit =
res.foreach(this += _)
}
}
13 changes: 7 additions & 6 deletions hail/src/main/scala/is/hail/expr/ir/Simplify.scala
@@ -30,8 +30,9 @@ object Simplify {
private[this] def simplifyValue(ctx: ExecuteContext): IR => IR =
visitNode(
Simplify(ctx, _),
rewriteValueNode,
simplifyValue(ctx))
rewriteValueNode(ctx),
simplifyValue(ctx)
)

private[this] def simplifyTable(ctx: ExecuteContext)(tir: TableIR): TableIR =
visitNode(
@@ -55,8 +56,8 @@ object Simplify {
)(bmir)
}

private[this] def rewriteValueNode(ir: IR): Option[IR] =
valueRules.lift(ir).orElse(numericRules(ir))
private[this] def rewriteValueNode(ctx: ExecuteContext)(ir: IR): Option[IR] =
valueRules(ctx).lift(ir).orElse(numericRules(ir))

private[this] def rewriteTableNode(ctx: ExecuteContext)(tir: TableIR): Option[TableIR] =
tableRules(ctx).lift(tir)
@@ -218,7 +219,7 @@ object Simplify {
).reduce((f, g) => ir => f(ir).orElse(g(ir)))
}

private[this] def valueRules: PartialFunction[IR, IR] = {
private[this] def valueRules(ctx: ExecuteContext): PartialFunction[IR, IR] = {
// propagate NA
case x: IR if hasMissingStrictChild(x) =>
NA(x.typ)
@@ -456,7 +457,7 @@ object Simplify {
val rw = fieldNames.foldLeft[IR](Let(name, old, rewrite(body))) { case (comb, fieldName) =>
Let(newFieldRefs(fieldName).name, newFieldMap(fieldName), comb)
}
ForwardLets[IR](rw)
ForwardLets(ctx)(rw)

case SelectFields(old, fields) if tcoerce[TStruct](old.typ).fieldNames sameElements fields =>
old
@@ -30,7 +30,7 @@ case object SemanticHash extends Logging {
// Running the algorithm on the name-normalised IR
// removes sensitivity to compiler-generated names
val nameNormalizedIR = ctx.timer.time("NormalizeNames") {
new NormalizeNames(_.toString, allowFreeVariables = true)(root)
new NormalizeNames(_.toString, allowFreeVariables = true)(ctx, root)
}

val semhash = ctx.timer.time("Hash") {
@@ -53,7 +53,7 @@ object LowerAndExecuteShuffles {
StreamBufferedAggregate(Ref(streamName, streamTyp), bindIR(GetField(insGlob, "__initState")) { states =>
Begin(aggSigs.indices.map { aIdx => InitFromSerializedValue(aIdx, GetTupleElement(states, aIdx), aggSigs(aIdx).state) })
}, newKey, seq, "row", aggSigs, bufferSize)),
0, 0).noSharing
0, 0).noSharing(ctx)


val analyses = LoweringAnalyses(partiallyAggregated, ctx)
11 changes: 8 additions & 3 deletions hail/src/main/scala/is/hail/expr/ir/lowering/LoweringPass.scala
@@ -8,11 +8,16 @@ import is.hail.utils._

final case class IrMetadata(semhash: Option[SemanticHash.Type]) {
private[this] var hashCounter: Int = 0
private[this] var markCounter: Int = 0
def nextHash: Option[SemanticHash.Type] = {
hashCounter += 1
semhash.map(SemanticHash.extend(_, SemanticHash.Bytes.fromInt(hashCounter)))
}

def nextFlag: Int = {
markCounter += 1
markCounter
}
}

trait LoweringPass {
@@ -102,7 +107,7 @@ case object LowerArrayAggsToRunAggsPass extends LoweringPass {
val context: String = "LowerArrayAggsToRunAggs"

def transform(ctx: ExecuteContext, ir: BaseIR): BaseIR = {
val x = ir.noSharing
val x = ir.noSharing(ctx)
val r = Requiredness(x, ctx)
RewriteBottomUp(x, {
case x@StreamAgg(a, name, query) =>
@@ -126,7 +131,7 @@

if (newNode.typ != x.typ)
throw new RuntimeException(s"types differ:\n new: ${newNode.typ}\n old: ${x.typ}")
Some(newNode.noSharing)
Some(newNode.noSharing(ctx))
case x@StreamAggScan(a, name, query) =>
val res = genUID()
val aggs = Extract(query, res, r, isScan=true)
@@ -142,7 +147,7 @@
}
if (newNode.typ != x.typ)
throw new RuntimeException(s"types differ:\n new: ${ newNode.typ }\n old: ${ x.typ }")
Some(newNode.noSharing)
Some(newNode.noSharing(ctx))
case _ => None
})
}
13 changes: 8 additions & 5 deletions hail/src/main/scala/is/hail/types/TypeWithRequiredness.scala
@@ -98,8 +98,11 @@ sealed abstract class BaseTypeWithRequiredness {
throw new AssertionError(
s"children lengths differed ${children.length} ${newChildren.length}. ${children} ${newChildren} ${this}")
}
(children, newChildren).zipped.foreach { (r1, r2) =>
r1.unionFrom(r2)

// foreach on zipped seqs is very slow as the implementation
// doesn't know that the seqs are the same length.
for (i <- children.indices) {
children(i).unionFrom(newChildren(i))
}
}

Expand Down Expand Up @@ -500,12 +503,12 @@ object RTable {
RTable(rowStruct.fields.map(f => f.name -> f.typ), globStruct.fields.map(f => f.name -> f.typ), key)
}

def fromTableStage(ec: ExecuteContext, s: TableStage): RTable = {
def fromTableStage(ctx: ExecuteContext, s: TableStage): RTable = {
def virtualTypeWithReq(ir: IR, inputs: Env[PType]): VirtualTypeWithReq = {
import is.hail.expr.ir.Requiredness
val ns = ir.noSharing
val ns = ir.noSharing(ctx)
val usesAndDefs = ComputeUsesAndDefs(ns, errorIfFreeVariables = false)
val req = Requiredness.apply(ns, usesAndDefs, ec, inputs)
val req = Requiredness.apply(ns, usesAndDefs, ctx, inputs)
VirtualTypeWithReq(ir.typ, req.lookup(ns).asInstanceOf[TypeWithRequiredness])
}
