Skip to content

Commit

Permalink
simplifications
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-schultz committed Feb 6, 2024
1 parent e5ee4cd commit e6a3710
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 35 deletions.
66 changes: 36 additions & 30 deletions hail/src/main/scala/is/hail/expr/ir/Emit.scala
Original file line number Diff line number Diff line change
Expand Up @@ -820,14 +820,7 @@ class Emit[C](val ctx: EmitContext, val cb: EmitClassBuilder[C]) {
emitI(cond).consume(cb, {}, m => cb.if_(m.asBoolean.value, emitVoid(cnsq), emitVoid(altr)))

case let: Let =>
println(Pretty.sexprStyle(let))
val newEnv = emitLetBindings(
emitI = (ir, cb, env, r) =>
if (ir.typ.isInstanceOf[TStream])
EmitStream.produce(this, ir, cb, cb.emb, r, env, container)
else emitI(ir, cb = cb, env = env, region = r),
emitVoid = (ir, cb, env, r) => emitVoid(ir, env = env, region = r, cb = cb),
)(let, cb, env, region)
val newEnv = emitLetBindings(let, cb, env, region, container, loopEnv)
emitVoid(let.body, env = newEnv)

case StreamFor(a, valueName, body) =>
Expand Down Expand Up @@ -1111,14 +1104,7 @@ class Emit[C](val ctx: EmitContext, val cb: EmitClassBuilder[C]) {
presentPC(primitive(m))

case let: Let =>
val newEnv = emitLetBindings(
emitI = (ir, cb, env, r) =>
if (ir.typ.isInstanceOf[TStream]) // emitStream(ir, cb, region, env = env)
EmitStream.produce(this, ir, cb, cb.emb, r, env, container)
else emitInNewBuilder(cb, ir, region = r, env = env),
emitVoid = (ir, cb, env, r) =>
emitVoid(ir, cb = cb, env = env, region = r),
)(let, cb, env, region)
val newEnv = emitLetBindings(let, cb, env, region, container, loopEnv)
emitI(let.body, env = newEnv)

case Coalesce(values) =>
Expand Down Expand Up @@ -3625,21 +3611,35 @@ class Emit[C](val ctx: EmitContext, val cb: EmitClassBuilder[C]) {
cb.memoize(cb.invokeCode[Boolean](sort, cb.this_, region, l, r))
}

/** Emit the bindings (but not the body) of `let`. If possible, split bindings into chunks, and
* emit each chunk in a separate method.
*/
// TODO: splitting logic should get lifted into ComputeMethodSplits
def emitLetBindings(
emitI: (IR, EmitCodeBuilder, EmitEnv, Value[Region]) => IEmitCode,
emitVoid: (IR, EmitCodeBuilder, EmitEnv, Value[Region]) => Unit,
)(
let: Let,
cb: EmitCodeBuilder,
env: EmitEnv,
r: Value[Region],
container: Option[AggContainer],
loopEnv: Option[Env[LoopRef]],
): EmitEnv = {
def emitI(ir: IR, cb: EmitCodeBuilder, env: EmitEnv, r: Value[Region]): IEmitCode =
if (ir.typ.isInstanceOf[TStream])
EmitStream.produce(this, ir, cb, cb.emb, r, env, container)
else this.emitI(ir, cb, r, env, container, loopEnv)

def emitVoid(ir: IR, cb: EmitCodeBuilder, env: EmitEnv, r: Value[Region]): Unit =
this.emitVoid(cb, ir, r, env, container, loopEnv)

val uses: mutable.Set[String] =
ctx.usesAndDefs.uses.get(let) match {
case Some(refs) => refs.map(_.t.name)
case None => mutable.Set.empty
}

/* Emit a sequence of bindings into a code builder. Each is added to the environment of all
* following bindings. Any bindings which is unused and has no side effects is skipped (this is
* mostly an optimization, but it is important not to emit unused streams). */
def emitChunk(cb: EmitCodeBuilder, bindings: Seq[(String, IR)], env: EmitEnv, r: Value[Region])
: EmitEnv =
bindings.foldLeft(env) { case (newEnv, (name, ir)) =>
Expand All @@ -3655,6 +3655,9 @@ class Emit[C](val ctx: EmitContext, val cb: EmitClassBuilder[C]) {
}
}

/* Bindings before chunkStart have been emitted. Bindings in the range chunkStart <= i < pos are
* a pending chunk, which have not yet been emitted. chunkSize is the number of non-skipped
* bindings in the pending chunk. groupIdx is how many chunks have already been emitted. */
@tailrec def go(
env: EmitEnv,
chunkStart: Int,
Expand All @@ -3663,7 +3666,7 @@ class Emit[C](val ctx: EmitContext, val cb: EmitClassBuilder[C]) {
groupIdx: Int,
): EmitEnv = {

def emitChunkWrapped(): EmitEnv = {
def emitChunkInSeparateMethod(): EmitEnv = {
val mb = cb.emb.genEmitMethod(
s"begin_group_$groupIdx",
FastSeq[ParamType](classInfo[Region]),
Expand All @@ -3678,9 +3681,13 @@ class Emit[C](val ctx: EmitContext, val cb: EmitClassBuilder[C]) {
newEnv
}

def cantEmitInSeparateMethod(ir: IR): Boolean =
ir.typ.isInstanceOf[TStream] || ctx.inLoopCriticalPath.contains(ir)

// end of bindings, emit any pending chunk and return the final environment
if (pos == let.bindings.length) {
if (chunkSize > 0)
return emitChunkWrapped()
return emitChunkInSeparateMethod()
else
return env
}
Expand All @@ -3690,27 +3697,26 @@ class Emit[C](val ctx: EmitContext, val cb: EmitClassBuilder[C]) {
// skip over unused streams
if (curIR.typ.isInstanceOf[TStream] && !uses.contains(curName)) {
go(env, chunkStart, pos + 1, chunkSize, groupIdx)
} else if (chunkSize == 16 || (chunkSize > 0 && curIR.typ.isInstanceOf[TStream])) {
// emit the current chunk if it's either max size, or broken by a stream
val newEnv = emitChunkWrapped()
} else if (chunkSize == 16 || (chunkSize > 0 && cantEmitInSeparateMethod(curIR))) {
/* emit the current chunk if it's either max size, or broken by a stream or other control
* flow */
val newEnv = emitChunkInSeparateMethod()
go(newEnv, pos, pos, 0, groupIdx + 1)
} else if (curIR.typ.isInstanceOf[TStream]) {
// emit a stream, assuming we've already emitted any prior chunk
assert(chunkSize == 0) // no pending bindings
val value = emitI(curIR, cb, env, r)
val memo = cb.memoizeMaybeStreamValue(value, s"let_$curName")
val newEnv = env.bind(curName, memo)
go(newEnv, pos + 1, pos + 1, 0, groupIdx)
} else {
// add cur binding to pending group
// add cur binding to pending chunk
go(env, chunkStart, pos + 1, chunkSize + 1, groupIdx)
}
}

if (
let.bindings.size > 4 &&
!ctx.inLoopCriticalPath.contains(let) &&
let.bindings.forall(x => !ctx.inLoopCriticalPath.contains(x._2))
) {
// don't split into separate methods if the bindings list is small
if (let.bindings.size > 4) {
go(env, 0, 0, 0, 0)
} else {
emitChunk(cb, let.bindings, env, r)
Expand Down
9 changes: 4 additions & 5 deletions hail/src/main/scala/is/hail/expr/ir/streams/EmitStream.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@ import is.hail.utils._
import is.hail.variant.Locus

import java.util

import org.objectweb.asm.Opcodes._

import scala.annotation.nowarn

abstract class StreamProducer {

// method builder where this stream is valid
Expand Down Expand Up @@ -144,6 +145,7 @@ object EmitStream {
container: Option[AggContainer],
): IEmitCode = {

@nowarn("cat=unused-locals&msg=local default argument")
def emitVoid(
ir: IR,
cb: EmitCodeBuilder,
Expand Down Expand Up @@ -364,10 +366,7 @@ object EmitStream {
}

case let: Let =>
val newEnv = emitter.emitLetBindings(
emitI = (ir, cb, env, r) => emit(ir, cb, region = r, env = env),
emitVoid = (ir, cb, env, r) => emitVoid(ir, cb, region = r, env = env),
)(let, cb, env, outerRegion)
val newEnv = emitter.emitLetBindings(let, cb, env, outerRegion, container, None)
produce(let.body, cb, env = newEnv)

case In(n, _) =>
Expand Down

0 comments on commit e6a3710

Please sign in to comment.