[compiler] Minor Requiredness Performance Enhancements #13991

Merged
merged 7 commits into from
Nov 16, 2023
Changes from 5 commits
6 changes: 5 additions & 1 deletion hail/src/main/scala/is/hail/expr/ir/BaseIR.scala
@@ -1,5 +1,6 @@
package is.hail.expr.ir

import is.hail.backend.ExecuteContext
import is.hail.types.BaseType
import is.hail.utils.StackSafe._
import is.hail.utils._
@@ -15,7 +16,10 @@ abstract class BaseIR {

def deepCopy(): this.type = copy(newChildren = childrenSeq.map(_.deepCopy())).asInstanceOf[this.type]

lazy val noSharing: this.type = if (HasIRSharing(this)) this.deepCopy() else this
def noSharing(ctx: ExecuteContext): this.type =
if (HasIRSharing(ctx)(this)) this.deepCopy() else this

var mark: Int = 0
Collaborator

Do I understand the use of this correctly?

For use as a boolean flag by IR passes. Each pass uses a different sentinel value to encode "true" (and anything else is false). As long as we maintain the global invariant that no two passes use the same sentinel value, this allows us to reuse this field across passes without ever having to initialize it at the start of a pass.

If this is accurate, could you add a comment here? The invariant is important: if anybody were to use this field inconsistently with the above, it would break all other passes that use it.

Member Author

Exactly. Thanks for explaining it in a way I couldn't haha! I'll add the comment :)
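The sentinel scheme described in this thread can be illustrated with a minimal sketch. `Node`, `FlagSource`, and `countDistinct` are hypothetical stand-ins invented for this example, not Hail classes; the only idea taken from the PR is that each pass draws a fresh integer and treats "mark equals my integer" as true.

```scala
// Hypothetical stand-ins for illustration; these are not the Hail classes.
final class Node(val name: String) {
  var mark: Int = 0 // 0 means "never marked by any pass"
}

final class FlagSource {
  private[this] var counter: Int = 0
  // Each call returns a value that no earlier pass has been given.
  def nextFlag(): Int = { counter += 1; counter }
}

object MarkDemo {
  // Counts distinct node objects, using `flag` as this pass's encoding of "seen".
  // Marks left behind by earlier passes hold other values, so they read as "not seen"
  // and the field never needs to be reset before the pass starts.
  def countDistinct(nodes: Seq[Node], flag: Int): Int = {
    var n = 0
    for (node <- nodes) {
      if (node.mark != flag) {
        node.mark = flag
        n += 1
      }
    }
    n
  }
}
```

If two passes were handed the same flag, the second would see stale marks as "seen", which is the failure mode the reviewer asks to have documented.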


def mapChildrenWithIndex(f: (BaseIR, Int) => BaseIR): BaseIR = {
val newChildren = childrenSeq.view.zipWithIndex.map(f.tupled).toArray
18 changes: 9 additions & 9 deletions hail/src/main/scala/is/hail/expr/ir/Compile.scala
@@ -32,17 +32,17 @@ object Compile {
print: Option[PrintWriter] = None
): (Option[SingleCodeType], (HailClassLoader, FS, HailTaskContext, Region) => F) = {

val normalizeNames = new NormalizeNames(_.toString)
val normalizedBody = normalizeNames(body,
Env(params.map { case (n, _) => n -> n }: _*))
val normalizedBody = new NormalizeNames(_.toString)(ctx, body,
Env(params.map { case (n, _) => n -> n }: _*)
)
val k = CodeCacheKey(FastSeq[AggStateSig](), params.map { case (n, pt) => (n, pt) }, normalizedBody)
(ctx.backend.lookupOrCompileCachedFunction[F](k) {

var ir = body
ir = Subst(ir, BindingEnv(params
.zipWithIndex
.foldLeft(Env.empty[IR]) { case (e, ((n, t), i)) => e.bind(n, In(i, t)) }))
ir = LoweringPipeline.compileLowerer(optimize).apply(ctx, ir).asInstanceOf[IR].noSharing
ir = LoweringPipeline.compileLowerer(optimize).apply(ctx, ir).asInstanceOf[IR].noSharing(ctx)

TypeCheck(ctx, ir, BindingEnv.empty)

@@ -85,17 +85,17 @@ object CompileWithAggregators {
body: IR,
optimize: Boolean = true
): (Option[SingleCodeType], (HailClassLoader, FS, HailTaskContext, Region) => (F with FunctionWithAggRegion)) = {
val normalizeNames = new NormalizeNames(_.toString)
val normalizedBody = normalizeNames(body,
Env(params.map { case (n, _) => n -> n }: _*))
val normalizedBody = new NormalizeNames(_.toString)(ctx, body,
Env(params.map { case (n, _) => n -> n }: _*)
)
val k = CodeCacheKey(aggSigs, params.map { case (n, pt) => (n, pt) }, normalizedBody)
(ctx.backend.lookupOrCompileCachedFunction[F with FunctionWithAggRegion](k) {

var ir = body
ir = Subst(ir, BindingEnv(params
.zipWithIndex
.foldLeft(Env.empty[IR]) { case (e, ((n, t), i)) => e.bind(n, In(i, t)) }))
ir = LoweringPipeline.compileLowerer(optimize).apply(ctx, ir).asInstanceOf[IR].noSharing
ir = LoweringPipeline.compileLowerer(optimize).apply(ctx, ir).asInstanceOf[IR].noSharing(ctx)

TypeCheck(ctx, ir, BindingEnv(Env.fromSeq[Type](params.map { case (name, t) => name -> t.virtualType })))

@@ -184,7 +184,7 @@ object CompileIterator {

val outerRegion = outerRegionField

val ir = LoweringPipeline.compileLowerer(true)(ctx, body).asInstanceOf[IR].noSharing
val ir = LoweringPipeline.compileLowerer(true)(ctx, body).asInstanceOf[IR].noSharing(ctx)
TypeCheck(ctx, ir)

var elementAddress: Settable[Long] = null
2 changes: 1 addition & 1 deletion hail/src/main/scala/is/hail/expr/ir/Emit.scala
@@ -30,7 +30,7 @@ object EmitContext {
def analyze(ctx: ExecuteContext, ir: IR, pTypeEnv: Env[PType] = Env.empty): EmitContext = {
ctx.timer.time("EmitContext.analyze") {
val usesAndDefs = ComputeUsesAndDefs(ir, errorIfFreeVariables = false)
val requiredness = Requiredness.apply(ir, usesAndDefs, null, pTypeEnv)
val requiredness = Requiredness(ir, usesAndDefs, ctx, pTypeEnv)
val inLoopCriticalPath = ControlFlowPreventsSplit(ir, ParentPointers(ir), usesAndDefs)
val methodSplits = ComputeMethodSplits(ctx, ir, inLoopCriticalPath)
new EmitContext(ctx, requiredness, usesAndDefs, methodSplits, inLoopCriticalPath, Memo.empty[Unit])
6 changes: 3 additions & 3 deletions hail/src/main/scala/is/hail/expr/ir/ForwardLets.scala
@@ -1,12 +1,12 @@
package is.hail.expr.ir

import is.hail.utils._
import is.hail.backend.ExecuteContext

import scala.collection.mutable

object ForwardLets {
def apply[T <: BaseIR](ir0: T): T = {
val ir1 = new NormalizeNames(_ => genUID(), allowFreeVariables = true).apply(ir0)
def apply[T <: BaseIR](ctx: ExecuteContext)(ir0: T): T = {
val ir1 = new NormalizeNames(_ => genUID(), allowFreeVariables = true)(ctx, ir0)
val UsesAndDefs(uses, defs, _) = ComputeUsesAndDefs(ir1, errorIfFreeVariables = false)
val nestingDepth = NestingDepth(ir1)

2 changes: 1 addition & 1 deletion hail/src/main/scala/is/hail/expr/ir/Interpret.scala
@@ -24,7 +24,7 @@ object Interpret {
apply(tir, ctx, optimize = true)

def apply(tir: TableIR, ctx: ExecuteContext, optimize: Boolean): TableValue = {
val lowered = LoweringPipeline.legacyRelationalLowerer(optimize)(ctx, tir).asInstanceOf[TableIR].noSharing
val lowered = LoweringPipeline.legacyRelationalLowerer(optimize)(ctx, tir).asInstanceOf[TableIR].noSharing(ctx)
lowered.analyzeAndExecute(ctx).asTableValue(ctx)
}

@@ -60,6 +60,6 @@ object LowerOrInterpretNonCompilable {
}
}

rewrite(ir.noSharing, mutable.HashMap.empty)
rewrite(ir.noSharing(ctx), mutable.HashMap.empty)
}
}
10 changes: 7 additions & 3 deletions hail/src/main/scala/is/hail/expr/ir/NormalizeNames.scala
@@ -1,5 +1,6 @@
package is.hail.expr.ir

import is.hail.backend.ExecuteContext
import is.hail.utils.StackSafe._

class NormalizeNames(normFunction: Int => String, allowFreeVariables: Boolean = false) {
@@ -10,11 +11,14 @@ class NormalizeNames(normFunction: Int => String, allowFreeVariables: Boolean =
normFunction(count)
}

def apply(ir: IR, env: Env[String]): IR = apply(ir.noSharing, BindingEnv(env))
def apply(ctx: ExecuteContext, ir: IR, env: Env[String]): IR =
normalizeIR(ir.noSharing(ctx), BindingEnv(env)).run().asInstanceOf[IR]

def apply(ir: IR, env: BindingEnv[String]): IR = normalizeIR(ir.noSharing, env).run().asInstanceOf[IR]
def apply(ctx: ExecuteContext, ir: IR, env: BindingEnv[String]): IR =
normalizeIR(ir.noSharing(ctx), env).run().asInstanceOf[IR]

def apply(ir: BaseIR): BaseIR = normalizeIR(ir.noSharing, BindingEnv(agg=Some(Env.empty), scan=Some(Env.empty))).run()
def apply(ctx: ExecuteContext, ir: BaseIR): BaseIR =
normalizeIR(ir.noSharing(ctx), BindingEnv(agg=Some(Env.empty), scan=Some(Env.empty))).run()

private def normalizeIR(ir: BaseIR, env: BindingEnv[String], context: Array[String] = Array()): StackFrame[BaseIR] = {

4 changes: 2 additions & 2 deletions hail/src/main/scala/is/hail/expr/ir/Optimize.scala
@@ -21,9 +21,9 @@ object Optimize {
last = ir
runOpt(FoldConstants(ctx, _), iter, "FoldConstants")
runOpt(ExtractIntervalFilters(ctx, _), iter, "ExtractIntervalFilters")
runOpt(normalizeNames(_), iter, "NormalizeNames")
runOpt(normalizeNames(ctx, _), iter, "NormalizeNames")
runOpt(Simplify(ctx, _), iter, "Simplify")
runOpt(ForwardLets(_), iter, "ForwardLets")
runOpt(ForwardLets(ctx), iter, "ForwardLets")
runOpt(ForwardRelationalLets(_), iter, "ForwardRelationalLets")
runOpt(PruneDeadFields(ctx, _), iter, "PruneDeadFields")

21 changes: 9 additions & 12 deletions hail/src/main/scala/is/hail/expr/ir/RefEquality.scala
@@ -1,5 +1,7 @@
package is.hail.expr.ir

import is.hail.backend.ExecuteContext

import scala.collection.mutable

object RefEquality {
@@ -61,19 +63,14 @@ class Memo[T] private(val m: mutable.HashMap[RefEquality[BaseIR], T]) {


object HasIRSharing {
def apply(ir: BaseIR): Boolean = {
val m = mutable.HashSet.empty[RefEquality[BaseIR]]

def recur(x: BaseIR): Boolean = {
val re = RefEquality(x)
if (m.contains(re))
true
else {
m.add(re)
x.children.exists(recur)
}
def apply(ctx: ExecuteContext)(ir: BaseIR): Boolean = {
val mark = ctx.irMetadata.nextFlag

for (node <- IRTraversal.levelOrder(ir)) {
if (node.mark == mark) return true
node.mark = mark
}

recur(ir)
false
}
}
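The mark-based sharing check in the hunk above can be sketched in isolation. This is a minimal illustration under stated assumptions: `Tree` is a hypothetical node type, and the breadth-first queue is written out by hand rather than using Hail's `IRTraversal.levelOrder`, whose exact behaviour is not shown in this diff.

```scala
import scala.collection.mutable

// Hypothetical node type for illustration; not Hail's BaseIR.
final class Tree(val children: Seq[Tree] = Seq.empty) {
  var mark: Int = 0
}

object SharingDemo {
  // Returns true if some node object is reachable along more than one path.
  // `flag` must be a value that no earlier traversal has written into `mark`.
  def hasSharing(root: Tree, flag: Int): Boolean = {
    val pending = mutable.Queue[Tree](root)
    while (pending.nonEmpty) {
      val node = pending.dequeue()
      if (node.mark == flag) return true // visited before in this traversal
      node.mark = flag
      pending ++= node.children
    }
    false
  }
}
```

Compared with the removed `HashSet[RefEquality[BaseIR]]` version, this does no hashing or allocation per node. The cost is that an early return leaves some nodes carrying the flag, so a later call must use a different flag value.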
30 changes: 26 additions & 4 deletions hail/src/main/scala/is/hail/expr/ir/Requiredness.scala
@@ -35,7 +35,7 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) {
type State = Memo[BaseTypeWithRequiredness]
private val cache = Memo.empty[BaseTypeWithRequiredness]
private val dependents = Memo.empty[mutable.Set[RefEquality[BaseIR]]]
private val q = mutable.Set[RefEquality[BaseIR]]()
private[this] val q = new Queue(ctx.irMetadata.nextFlag)

private val defs = Memo.empty[IndexedSeq[BaseTypeWithRequiredness]]
private val states = Memo.empty[IndexedSeq[TypeWithRequiredness]]
@@ -90,8 +90,7 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) {

def run(): Unit = {
while (q.nonEmpty) {
val node = q.head
q -= node
val node = q.pop()
if (analyze(node.t) && dependents.contains(node)) {
q ++= dependents.lookup(node)
}
@@ -615,7 +614,7 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) {
val eltType = tcoerce[RIterable](requiredness).elementType
eltType.unionFrom(lookup(joinF))
case StreamMultiMerge(as, _) =>
requiredness.union(as.forall(lookup(_).required))
requiredness.union(as.forall(lookup(_).required))
val elt = tcoerce[RStruct](tcoerce[RIterable](requiredness).elementType)
as.foreach { a =>
elt.unionFields(tcoerce[RStruct](tcoerce[RIterable](lookup(a)).elementType))
@@ -828,4 +827,27 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) {

requiredness.probeChangedAndReset()
}


final class Queue(val markFlag: Int) {
private[this] val q = mutable.Queue[RefEquality[BaseIR]]()

def nonEmpty: Boolean =
q.nonEmpty

def pop(): RefEquality[BaseIR] = {
val n = q.dequeue()
n.t.mark = 0
n
}

def +=(re: RefEquality[BaseIR]): Unit =
if (re.t.mark != markFlag) {
re.t.mark = markFlag
q += re
}

def ++=(res: Iterable[RefEquality[BaseIR]]): Unit =
res.foreach(this += _)
}
}
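The `Queue` added above combines a FIFO worklist with a per-node flag so that a node already waiting in the queue is not enqueued twice. A standalone sketch of that pattern follows; `Item` and `DedupQueue` are hypothetical names, and the sketch assumes the flag passed in is never 0, since 0 is used to mean "not queued".

```scala
import scala.collection.mutable

// Hypothetical element type for illustration.
final class Item(val id: Int) {
  var mark: Int = 0
}

// FIFO worklist that ignores a push of an item that is already waiting.
// Assumption: `flag` is nonzero, because 0 encodes "not currently queued".
final class DedupQueue(flag: Int) {
  private[this] val q = mutable.Queue[Item]()

  def nonEmpty: Boolean = q.nonEmpty
  def size: Int = q.size

  def push(item: Item): Unit =
    if (item.mark != flag) {
      item.mark = flag
      q += item
    }

  // Clearing the mark on removal lets the item be queued again later,
  // which a fixed-point analysis needs when a dependency changes.
  def pop(): Item = {
    val item = q.dequeue()
    item.mark = 0
    item
  }
}
```

A typical driver loop would be `while (q.nonEmpty) { val n = q.pop(); ... }`, pushing dependents whenever processing `n` changes its result.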
13 changes: 7 additions & 6 deletions hail/src/main/scala/is/hail/expr/ir/Simplify.scala
@@ -30,8 +30,9 @@ object Simplify {
private[this] def simplifyValue(ctx: ExecuteContext): IR => IR =
visitNode(
Simplify(ctx, _),
rewriteValueNode,
simplifyValue(ctx))
rewriteValueNode(ctx),
simplifyValue(ctx)
)

private[this] def simplifyTable(ctx: ExecuteContext)(tir: TableIR): TableIR =
visitNode(
@@ -55,8 +56,8 @@ object Simplify {
)(bmir)
}

private[this] def rewriteValueNode(ir: IR): Option[IR] =
valueRules.lift(ir).orElse(numericRules(ir))
private[this] def rewriteValueNode(ctx: ExecuteContext)(ir: IR): Option[IR] =
valueRules(ctx).lift(ir).orElse(numericRules(ir))

private[this] def rewriteTableNode(ctx: ExecuteContext)(tir: TableIR): Option[TableIR] =
tableRules(ctx).lift(tir)
@@ -218,7 +219,7 @@ object Simplify {
).reduce((f, g) => ir => f(ir).orElse(g(ir)))
}

private[this] def valueRules: PartialFunction[IR, IR] = {
private[this] def valueRules(ctx: ExecuteContext): PartialFunction[IR, IR] = {
// propagate NA
case x: IR if hasMissingStrictChild(x) =>
NA(x.typ)
@@ -456,7 +457,7 @@ object Simplify {
val rw = fieldNames.foldLeft[IR](Let(name, old, rewrite(body))) { case (comb, fieldName) =>
Let(newFieldRefs(fieldName).name, newFieldMap(fieldName), comb)
}
ForwardLets[IR](rw)
ForwardLets(ctx)(rw)

case SelectFields(old, fields) if tcoerce[TStruct](old.typ).fieldNames sameElements fields =>
old
@@ -30,7 +30,7 @@ case object SemanticHash extends Logging {
// Running the algorithm on the name-normalised IR
// removes sensitivity to compiler-generated names
val nameNormalizedIR = ctx.timer.time("NormalizeNames") {
new NormalizeNames(_.toString, allowFreeVariables = true)(root)
new NormalizeNames(_.toString, allowFreeVariables = true)(ctx, root)
}

val semhash = ctx.timer.time("Hash") {
@@ -53,7 +53,7 @@ object LowerAndExecuteShuffles {
StreamBufferedAggregate(Ref(streamName, streamTyp), bindIR(GetField(insGlob, "__initState")) { states =>
Begin(aggSigs.indices.map { aIdx => InitFromSerializedValue(aIdx, GetTupleElement(states, aIdx), aggSigs(aIdx).state) })
}, newKey, seq, "row", aggSigs, bufferSize)),
0, 0).noSharing
0, 0).noSharing(ctx)


val analyses = LoweringAnalyses(partiallyAggregated, ctx)
11 changes: 8 additions & 3 deletions hail/src/main/scala/is/hail/expr/ir/lowering/LoweringPass.scala
@@ -8,11 +8,16 @@ import is.hail.utils._

final case class IrMetadata(semhash: Option[SemanticHash.Type]) {
private[this] var hashCounter: Int = 0
private[this] var markCounter: Int = 0
def nextHash: Option[SemanticHash.Type] = {
hashCounter += 1
semhash.map(SemanticHash.extend(_, SemanticHash.Bytes.fromInt(hashCounter)))
}

def nextFlag: Int = {
markCounter += 1
markCounter
}
}

trait LoweringPass {
@@ -102,7 +107,7 @@ case object LowerArrayAggsToRunAggsPass extends LoweringPass {
val context: String = "LowerArrayAggsToRunAggs"

def transform(ctx: ExecuteContext, ir: BaseIR): BaseIR = {
val x = ir.noSharing
val x = ir.noSharing(ctx)
val r = Requiredness(x, ctx)
RewriteBottomUp(x, {
case x@StreamAgg(a, name, query) =>
@@ -126,7 +131,7 @@ case object LowerArrayAggsToRunAggsPass extends LoweringPass {

if (newNode.typ != x.typ)
throw new RuntimeException(s"types differ:\n new: ${newNode.typ}\n old: ${x.typ}")
Some(newNode.noSharing)
Some(newNode.noSharing(ctx))
case x@StreamAggScan(a, name, query) =>
val res = genUID()
val aggs = Extract(query, res, r, isScan=true)
@@ -142,7 +147,7 @@ case object LowerArrayAggsToRunAggsPass extends LoweringPass {
}
if (newNode.typ != x.typ)
throw new RuntimeException(s"types differ:\n new: ${ newNode.typ }\n old: ${ x.typ }")
Some(newNode.noSharing)
Some(newNode.noSharing(ctx))
case _ => None
})
}
13 changes: 8 additions & 5 deletions hail/src/main/scala/is/hail/types/TypeWithRequiredness.scala
@@ -98,8 +98,11 @@ sealed abstract class BaseTypeWithRequiredness {
throw new AssertionError(
s"children lengths differed ${children.length} ${newChildren.length}. ${children} ${newChildren} ${this}")
}
(children, newChildren).zipped.foreach { (r1, r2) =>
r1.unionFrom(r2)

// foreach on zipped seqs is very slow as the implementation
// doesn't know that the seqs are the same length.
Comment on lines +102 to +103
Collaborator

I feel like this comment is out of place here. We use both of these patterns all over the codebase, and we shouldn't comment on their relative performance every time. Perhaps we should start a doc of these kinds of scala performance gotchas that we can refer to, and reevaluate with future scala version changes.

Member Author

Thanks for your feedback.

A zip and foreach is more like the code I want to write, and indeed what I will write when not in a hotspot. A comment will prevent my future self from getting upset at whoever wrote this low-level ugly crap and changing it back.

I don't think a doc would be that practical. Knowing me, I'd likely forget about it rather than consult it whenever I need to write simple code like a for-loop.

for (i <- children.indices) {
children(i).unionFrom(newChildren(i))
}
}
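The two loop shapes discussed in this thread compute the same result, which a small sketch can make concrete. `Cell` and its `unionFrom` are hypothetical stand-ins (the real `unionFrom` semantics are not shown in this diff), and the sketch makes no performance claim of its own; the slowdown of the zipped form is the PR author's observation.

```scala
object UnionDemo {
  // Hypothetical mutable cell; the "stay required only if both are" rule is an assumption.
  final class Cell(var required: Boolean) {
    def unionFrom(other: Cell): Unit = {
      required = required && other.required
    }
  }

  // The form the author would normally prefer: pair up elements, then visit each pair.
  def unionZip(xs: IndexedSeq[Cell], ys: IndexedSeq[Cell]): Unit =
    xs.zip(ys).foreach { case (x, y) => x.unionFrom(y) }

  // The form used in the hot path: index directly, relying on a prior length check.
  def unionIndexed(xs: IndexedSeq[Cell], ys: IndexedSeq[Cell]): Unit = {
    require(xs.length == ys.length, "children lengths differed")
    for (i <- xs.indices) {
      xs(i).unionFrom(ys(i))
    }
  }
}
```

The indexed form depends on the length check that precedes it in the diff; without that check, a shorter `ys` would throw instead of being silently truncated as `zip` does.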

@@ -500,12 +503,12 @@ object RTable {
RTable(rowStruct.fields.map(f => f.name -> f.typ), globStruct.fields.map(f => f.name -> f.typ), key)
}

def fromTableStage(ec: ExecuteContext, s: TableStage): RTable = {
def fromTableStage(ctx: ExecuteContext, s: TableStage): RTable = {
def virtualTypeWithReq(ir: IR, inputs: Env[PType]): VirtualTypeWithReq = {
import is.hail.expr.ir.Requiredness
val ns = ir.noSharing
val ns = ir.noSharing(ctx)
val usesAndDefs = ComputeUsesAndDefs(ns, errorIfFreeVariables = false)
val req = Requiredness.apply(ns, usesAndDefs, ec, inputs)
val req = Requiredness.apply(ns, usesAndDefs, ctx, inputs)
VirtualTypeWithReq(ir.typ, req.lookup(ns).asInstanceOf[TypeWithRequiredness])
}
