diff --git a/hail/src/main/scala/is/hail/expr/ir/BaseIR.scala b/hail/src/main/scala/is/hail/expr/ir/BaseIR.scala index b3008b7a6c9..51611bbd13e 100644 --- a/hail/src/main/scala/is/hail/expr/ir/BaseIR.scala +++ b/hail/src/main/scala/is/hail/expr/ir/BaseIR.scala @@ -1,5 +1,6 @@ package is.hail.expr.ir +import is.hail.backend.ExecuteContext import is.hail.types.BaseType import is.hail.types.virtual.Type import is.hail.utils.StackSafe._ @@ -16,7 +17,16 @@ abstract class BaseIR { def deepCopy(): this.type = copy(newChildren = childrenSeq.map(_.deepCopy())).asInstanceOf[this.type] - lazy val noSharing: this.type = if (HasIRSharing(this)) this.deepCopy() else this + def noSharing(ctx: ExecuteContext): this.type = + if (HasIRSharing(ctx)(this)) this.deepCopy() else this + + + // For use as a boolean flag by IR passes. Each pass uses a different sentinel value to encode + // "true" (and anything else is false). As long as we maintain the global invariant that no + // two passes use the same sentinel value, this allows us to reuse this field across passes + // without ever having to initialize it at the start of a pass. + // New sentinel values can be obtained by `nextFlag` on `IRMetadata`. + var mark: Int = 0 def mapChildrenWithIndex(f: (BaseIR, Int) => BaseIR): BaseIR = { val newChildren = childrenSeq.view.zipWithIndex.map(f.tupled).toArray diff --git a/hail/src/main/scala/is/hail/expr/ir/Compile.scala b/hail/src/main/scala/is/hail/expr/ir/Compile.scala index fb7d59ccfee..56e127af9b4 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Compile.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Compile.scala @@ -32,9 +32,9 @@ object Compile { print: Option[PrintWriter] = None ): (Option[SingleCodeType], (HailClassLoader, FS, HailTaskContext, Region) => F) = { - val normalizeNames = new NormalizeNames(_.toString) - val normalizedBody = normalizeNames(body, - Env(params.map { case (n, _) => n -> n }: _*)) + val normalizedBody = new NormalizeNames(_.toString)(ctx, body, + Env(params.map { case (n, _) => n -> n }: _*) + ) val k = CodeCacheKey(FastSeq[AggStateSig](), params.map { case (n, pt) => (n, pt) }, normalizedBody) (ctx.backend.lookupOrCompileCachedFunction[F](k) { @@ -42,7 +42,7 @@ object Compile { ir = Subst(ir, BindingEnv(params .zipWithIndex .foldLeft(Env.empty[IR]) { case (e, ((n, t), i)) => e.bind(n, In(i, t)) })) - ir = LoweringPipeline.compileLowerer(optimize).apply(ctx, ir).asInstanceOf[IR].noSharing + ir = LoweringPipeline.compileLowerer(optimize).apply(ctx, ir).asInstanceOf[IR].noSharing(ctx) TypeCheck(ctx, ir, BindingEnv.empty) @@ -85,9 +85,9 @@ object CompileWithAggregators { body: IR, optimize: Boolean = true ): (Option[SingleCodeType], (HailClassLoader, FS, HailTaskContext, Region) => (F with FunctionWithAggRegion)) = { - val normalizeNames = new NormalizeNames(_.toString) - val normalizedBody = normalizeNames(body, - Env(params.map { case (n, _) => n -> n }: _*)) + val normalizedBody = new NormalizeNames(_.toString)(ctx, body, + Env(params.map { case (n, _) => n -> n }: _*) + ) val k = CodeCacheKey(aggSigs, params.map { case (n, pt) => (n, pt) }, normalizedBody) (ctx.backend.lookupOrCompileCachedFunction[F with FunctionWithAggRegion](k) { @@ -95,7 +95,7 @@ object CompileWithAggregators { ir = Subst(ir, BindingEnv(params .zipWithIndex .foldLeft(Env.empty[IR]) { case (e, ((n, t), i)) => e.bind(n, In(i, t)) })) - ir = LoweringPipeline.compileLowerer(optimize).apply(ctx, ir).asInstanceOf[IR].noSharing + ir = LoweringPipeline.compileLowerer(optimize).apply(ctx, ir).asInstanceOf[IR].noSharing(ctx) TypeCheck(ctx, ir, BindingEnv(Env.fromSeq[Type](params.map { case (name, t) => name -> t.virtualType }))) @@ -184,7 +184,7 @@ object CompileIterator { val outerRegion = outerRegionField - val ir = LoweringPipeline.compileLowerer(true)(ctx, body).asInstanceOf[IR].noSharing + val ir = LoweringPipeline.compileLowerer(true)(ctx, body).asInstanceOf[IR].noSharing(ctx) TypeCheck(ctx, ir) var elementAddress: Settable[Long] = null diff --git a/hail/src/main/scala/is/hail/expr/ir/Emit.scala b/hail/src/main/scala/is/hail/expr/ir/Emit.scala index ecfc03297ef..cf9b9b006fe 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Emit.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Emit.scala @@ -30,7 +30,7 @@ object EmitContext { def analyze(ctx: ExecuteContext, ir: IR, pTypeEnv: Env[PType] = Env.empty): EmitContext = { ctx.timer.time("EmitContext.analyze") { val usesAndDefs = ComputeUsesAndDefs(ir, errorIfFreeVariables = false) - val requiredness = Requiredness.apply(ir, usesAndDefs, null, pTypeEnv) + val requiredness = Requiredness(ir, usesAndDefs, ctx, pTypeEnv) val inLoopCriticalPath = ControlFlowPreventsSplit(ir, ParentPointers(ir), usesAndDefs) val methodSplits = ComputeMethodSplits(ctx, ir, inLoopCriticalPath) new EmitContext(ctx, requiredness, usesAndDefs, methodSplits, inLoopCriticalPath, Memo.empty[Unit]) diff --git a/hail/src/main/scala/is/hail/expr/ir/ForwardLets.scala b/hail/src/main/scala/is/hail/expr/ir/ForwardLets.scala index 5c9b17232c4..1a5db502a32 100644 --- a/hail/src/main/scala/is/hail/expr/ir/ForwardLets.scala +++ b/hail/src/main/scala/is/hail/expr/ir/ForwardLets.scala @@ -1,12 +1,12 @@ package is.hail.expr.ir -import is.hail.utils._ +import is.hail.backend.ExecuteContext import scala.collection.mutable object ForwardLets { - def apply[T <: BaseIR](ir0: T): T = { - val ir1 = new NormalizeNames(_ => genUID(), allowFreeVariables = true).apply(ir0) + def apply[T <: BaseIR](ctx: ExecuteContext)(ir0: T): T = { + val ir1 = new NormalizeNames(_ => genUID(), allowFreeVariables = true)(ctx, ir0) val UsesAndDefs(uses, defs, _) = ComputeUsesAndDefs(ir1, errorIfFreeVariables = false) val nestingDepth = NestingDepth(ir1) diff --git a/hail/src/main/scala/is/hail/expr/ir/Interpret.scala b/hail/src/main/scala/is/hail/expr/ir/Interpret.scala index e6f7127865e..4b90bd6702c 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Interpret.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Interpret.scala @@ -24,7 +24,7 @@ object Interpret { apply(tir, ctx, optimize = true) def apply(tir: TableIR, ctx: ExecuteContext, optimize: Boolean): TableValue = { - val lowered = LoweringPipeline.legacyRelationalLowerer(optimize)(ctx, tir).asInstanceOf[TableIR].noSharing + val lowered = LoweringPipeline.legacyRelationalLowerer(optimize)(ctx, tir).asInstanceOf[TableIR].noSharing(ctx) lowered.analyzeAndExecute(ctx).asTableValue(ctx) } diff --git a/hail/src/main/scala/is/hail/expr/ir/LowerOrInterpretNonCompilable.scala b/hail/src/main/scala/is/hail/expr/ir/LowerOrInterpretNonCompilable.scala index abbbc58ccd8..f36863d5f2c 100644 --- a/hail/src/main/scala/is/hail/expr/ir/LowerOrInterpretNonCompilable.scala +++ b/hail/src/main/scala/is/hail/expr/ir/LowerOrInterpretNonCompilable.scala @@ -60,6 +60,6 @@ object LowerOrInterpretNonCompilable { } } - rewrite(ir.noSharing, mutable.HashMap.empty) + rewrite(ir.noSharing(ctx), mutable.HashMap.empty) } } diff --git a/hail/src/main/scala/is/hail/expr/ir/NormalizeNames.scala b/hail/src/main/scala/is/hail/expr/ir/NormalizeNames.scala index fab2531f5a6..740bf040597 100644 --- a/hail/src/main/scala/is/hail/expr/ir/NormalizeNames.scala +++ b/hail/src/main/scala/is/hail/expr/ir/NormalizeNames.scala @@ -1,5 +1,6 @@ package is.hail.expr.ir +import is.hail.backend.ExecuteContext import is.hail.utils.StackSafe._ class NormalizeNames(normFunction: Int => String, allowFreeVariables: Boolean = false) { @@ -10,11 +11,14 @@ class NormalizeNames(normFunction: Int => String, allowFreeVariables: Boolean = normFunction(count) } - def apply(ir: IR, env: Env[String]): IR = apply(ir.noSharing, BindingEnv(env)) + def apply(ctx: ExecuteContext, ir: IR, env: Env[String]): IR = + normalizeIR(ir.noSharing(ctx), BindingEnv(env)).run().asInstanceOf[IR] - def apply(ir: IR, env: BindingEnv[String]): IR = normalizeIR(ir.noSharing, env).run().asInstanceOf[IR] + def apply(ctx: ExecuteContext, ir: IR, env: BindingEnv[String]): IR = + normalizeIR(ir.noSharing(ctx), env).run().asInstanceOf[IR] - def apply(ir: BaseIR): BaseIR = normalizeIR(ir.noSharing, BindingEnv(agg=Some(Env.empty), scan=Some(Env.empty))).run() + def apply(ctx: ExecuteContext, ir: BaseIR): BaseIR = + normalizeIR(ir.noSharing(ctx), BindingEnv(agg=Some(Env.empty), scan=Some(Env.empty))).run() private def normalizeIR(ir: BaseIR, env: BindingEnv[String], context: Array[String] = Array()): StackFrame[BaseIR] = { diff --git a/hail/src/main/scala/is/hail/expr/ir/Optimize.scala b/hail/src/main/scala/is/hail/expr/ir/Optimize.scala index b8fff2e726d..45061ba1131 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Optimize.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Optimize.scala @@ -21,9 +21,9 @@ object Optimize { last = ir runOpt(FoldConstants(ctx, _), iter, "FoldConstants") runOpt(ExtractIntervalFilters(ctx, _), iter, "ExtractIntervalFilters") - runOpt(normalizeNames(_), iter, "NormalizeNames") + runOpt(normalizeNames(ctx, _), iter, "NormalizeNames") runOpt(Simplify(ctx, _), iter, "Simplify") - runOpt(ForwardLets(_), iter, "ForwardLets") + runOpt(ForwardLets(ctx), iter, "ForwardLets") runOpt(ForwardRelationalLets(_), iter, "ForwardRelationalLets") runOpt(PruneDeadFields(ctx, _), iter, "PruneDeadFields") diff --git a/hail/src/main/scala/is/hail/expr/ir/RefEquality.scala b/hail/src/main/scala/is/hail/expr/ir/RefEquality.scala index 7fa83bb96da..4a0e0acf5fe 100644 --- a/hail/src/main/scala/is/hail/expr/ir/RefEquality.scala +++ b/hail/src/main/scala/is/hail/expr/ir/RefEquality.scala @@ -1,5 +1,7 @@ package is.hail.expr.ir +import is.hail.backend.ExecuteContext + import scala.collection.mutable object RefEquality { @@ -61,19 +63,14 @@ class Memo[T] private(val m: mutable.HashMap[RefEquality[BaseIR], T]) { object HasIRSharing { - def apply(ir: BaseIR): Boolean = { - val m = mutable.HashSet.empty[RefEquality[BaseIR]] - - def recur(x: BaseIR): Boolean = { - val re = RefEquality(x) - if (m.contains(re)) - true - else { - m.add(re) - x.children.exists(recur) - } + def apply(ctx: ExecuteContext)(ir: BaseIR): Boolean = { + val mark = ctx.irMetadata.nextFlag + + for (node <- IRTraversal.levelOrder(ir)) { + if (node.mark == mark) return true + node.mark = mark } - recur(ir) + false } } diff --git a/hail/src/main/scala/is/hail/expr/ir/Requiredness.scala b/hail/src/main/scala/is/hail/expr/ir/Requiredness.scala index e2c6b0c317e..59f9ebbf519 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Requiredness.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Requiredness.scala @@ -35,7 +35,7 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) { type State = Memo[BaseTypeWithRequiredness] private val cache = Memo.empty[BaseTypeWithRequiredness] private val dependents = Memo.empty[mutable.Set[RefEquality[BaseIR]]] - private val q = mutable.Set[RefEquality[BaseIR]]() + private[this] val q = new Queue(ctx.irMetadata.nextFlag) private val defs = Memo.empty[IndexedSeq[BaseTypeWithRequiredness]] private val states = Memo.empty[IndexedSeq[TypeWithRequiredness]] @@ -90,8 +90,7 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) { def run(): Unit = { while (q.nonEmpty) { - val node = q.head - q -= node + val node = q.pop() if (analyze(node.t) && dependents.contains(node)) { q ++= dependents.lookup(node) } @@ -615,7 +614,7 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) { val eltType = tcoerce[RIterable](requiredness).elementType eltType.unionFrom(lookup(joinF)) case StreamMultiMerge(as, _) => - requiredness.union(as.forall(lookup(_).required)) + requiredness.union(as.forall(lookup(_).required)) val elt = tcoerce[RStruct](tcoerce[RIterable](requiredness).elementType) as.foreach { a => elt.unionFields(tcoerce[RStruct](tcoerce[RIterable](lookup(a)).elementType)) @@ -828,4 +827,27 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) { requiredness.probeChangedAndReset() } + + + final class Queue(val markFlag: Int) { + private[this] val q = mutable.Queue[RefEquality[BaseIR]]() + + def nonEmpty: Boolean = + q.nonEmpty + + def pop(): RefEquality[BaseIR] = { + val n = q.dequeue() + n.t.mark = 0 + n + } + + def +=(re: RefEquality[BaseIR]): Unit = + if (re.t.mark != markFlag) { + re.t.mark = markFlag + q += re + } + + def ++=(res: Iterable[RefEquality[BaseIR]]): Unit = + res.foreach(this += _) + } } diff --git a/hail/src/main/scala/is/hail/expr/ir/Simplify.scala b/hail/src/main/scala/is/hail/expr/ir/Simplify.scala index a963f36b350..1f9d0d3cffd 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Simplify.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Simplify.scala @@ -30,8 +30,9 @@ object Simplify { private[this] def simplifyValue(ctx: ExecuteContext): IR => IR = visitNode( Simplify(ctx, _), - rewriteValueNode, - simplifyValue(ctx)) + rewriteValueNode(ctx), + simplifyValue(ctx) + ) private[this] def simplifyTable(ctx: ExecuteContext)(tir: TableIR): TableIR = visitNode( @@ -55,8 +56,8 @@ object Simplify { )(bmir) } - private[this] def rewriteValueNode(ir: IR): Option[IR] = - valueRules.lift(ir).orElse(numericRules(ir)) + private[this] def rewriteValueNode(ctx: ExecuteContext)(ir: IR): Option[IR] = + valueRules(ctx).lift(ir).orElse(numericRules(ir)) private[this] def rewriteTableNode(ctx: ExecuteContext)(tir: TableIR): Option[TableIR] = tableRules(ctx).lift(tir) @@ -218,7 +219,7 @@ object Simplify { ).reduce((f, g) => ir => f(ir).orElse(g(ir))) } - private[this] def valueRules: PartialFunction[IR, IR] = { + private[this] def valueRules(ctx: ExecuteContext): PartialFunction[IR, IR] = { // propagate NA case x: IR if hasMissingStrictChild(x) => NA(x.typ) @@ -456,7 +457,7 @@ object Simplify { val rw = fieldNames.foldLeft[IR](Let(name, old, rewrite(body))) { case (comb, fieldName) => Let(newFieldRefs(fieldName).name, newFieldMap(fieldName), comb) } - ForwardLets[IR](rw) + ForwardLets(ctx)(rw) case SelectFields(old, fields) if tcoerce[TStruct](old.typ).fieldNames sameElements fields => old diff --git a/hail/src/main/scala/is/hail/expr/ir/analyses/SemanticHash.scala b/hail/src/main/scala/is/hail/expr/ir/analyses/SemanticHash.scala index a012c87f560..73812b9860a 100644 --- a/hail/src/main/scala/is/hail/expr/ir/analyses/SemanticHash.scala +++ b/hail/src/main/scala/is/hail/expr/ir/analyses/SemanticHash.scala @@ -30,7 +30,7 @@ case object SemanticHash extends Logging { // Running the algorithm on the name-normalised IR // removes sensitivity to compiler-generated names val nameNormalizedIR = ctx.timer.time("NormalizeNames") { - new NormalizeNames(_.toString, allowFreeVariables = true)(root) + new NormalizeNames(_.toString, allowFreeVariables = true)(ctx, root) } val semhash = ctx.timer.time("Hash") { diff --git a/hail/src/main/scala/is/hail/expr/ir/lowering/LowerAndExecuteShuffles.scala b/hail/src/main/scala/is/hail/expr/ir/lowering/LowerAndExecuteShuffles.scala index 78419622694..5ea8fb96ef8 100644 --- a/hail/src/main/scala/is/hail/expr/ir/lowering/LowerAndExecuteShuffles.scala +++ b/hail/src/main/scala/is/hail/expr/ir/lowering/LowerAndExecuteShuffles.scala @@ -53,7 +53,7 @@ object LowerAndExecuteShuffles { StreamBufferedAggregate(Ref(streamName, streamTyp), bindIR(GetField(insGlob, "__initState")) { states => Begin(aggSigs.indices.map { aIdx => InitFromSerializedValue(aIdx, GetTupleElement(states, aIdx), aggSigs(aIdx).state) }) }, newKey, seq, "row", aggSigs, bufferSize)), - 0, 0).noSharing + 0, 0).noSharing(ctx) val analyses = LoweringAnalyses(partiallyAggregated, ctx) diff --git a/hail/src/main/scala/is/hail/expr/ir/lowering/LoweringPass.scala b/hail/src/main/scala/is/hail/expr/ir/lowering/LoweringPass.scala index 772197a8e45..351ca79574a 100644 --- a/hail/src/main/scala/is/hail/expr/ir/lowering/LoweringPass.scala +++ b/hail/src/main/scala/is/hail/expr/ir/lowering/LoweringPass.scala @@ -8,11 +8,16 @@ import is.hail.utils._ final case class IrMetadata(semhash: Option[SemanticHash.Type]) { private[this] var hashCounter: Int = 0 + private[this] var markCounter: Int = 0 def nextHash: Option[SemanticHash.Type] = { hashCounter += 1 semhash.map(SemanticHash.extend(_, SemanticHash.Bytes.fromInt(hashCounter))) } + def nextFlag: Int = { + markCounter += 1 + markCounter + } } trait LoweringPass { @@ -102,7 +107,7 @@ case object LowerArrayAggsToRunAggsPass extends LoweringPass { val context: String = "LowerArrayAggsToRunAggs" def transform(ctx: ExecuteContext, ir: BaseIR): BaseIR = { - val x = ir.noSharing + val x = ir.noSharing(ctx) val r = Requiredness(x, ctx) RewriteBottomUp(x, { case x@StreamAgg(a, name, query) => @@ -126,7 +131,7 @@ case object LowerArrayAggsToRunAggsPass extends LoweringPass { if (newNode.typ != x.typ) throw new RuntimeException(s"types differ:\n new: ${newNode.typ}\n old: ${x.typ}") - Some(newNode.noSharing) + Some(newNode.noSharing(ctx)) case x@StreamAggScan(a, name, query) => val res = genUID() val aggs = Extract(query, res, r, isScan=true) @@ -142,7 +147,7 @@ case object LowerArrayAggsToRunAggsPass extends LoweringPass { } if (newNode.typ != x.typ) throw new RuntimeException(s"types differ:\n new: ${ newNode.typ }\n old: ${ x.typ }") - Some(newNode.noSharing) + Some(newNode.noSharing(ctx)) case _ => None }) } diff --git a/hail/src/main/scala/is/hail/types/TypeWithRequiredness.scala b/hail/src/main/scala/is/hail/types/TypeWithRequiredness.scala index fef61cd7e2e..1bfa5aa0a44 100644 --- a/hail/src/main/scala/is/hail/types/TypeWithRequiredness.scala +++ b/hail/src/main/scala/is/hail/types/TypeWithRequiredness.scala @@ -98,8 +98,11 @@ sealed abstract class BaseTypeWithRequiredness { throw new AssertionError( s"children lengths differed ${children.length} ${newChildren.length}. ${children} ${newChildren} ${this}") } - (children, newChildren).zipped.foreach { (r1, r2) => - r1.unionFrom(r2) + + // foreach on zipped seqs is very slow as the implementation + // doesn't know that the seqs are the same length. + for (i <- children.indices) { + children(i).unionFrom(newChildren(i)) } } @@ -500,12 +503,12 @@ object RTable { RTable(rowStruct.fields.map(f => f.name -> f.typ), globStruct.fields.map(f => f.name -> f.typ), key) } - def fromTableStage(ec: ExecuteContext, s: TableStage): RTable = { + def fromTableStage(ctx: ExecuteContext, s: TableStage): RTable = { def virtualTypeWithReq(ir: IR, inputs: Env[PType]): VirtualTypeWithReq = { import is.hail.expr.ir.Requiredness - val ns = ir.noSharing + val ns = ir.noSharing(ctx) val usesAndDefs = ComputeUsesAndDefs(ns, errorIfFreeVariables = false) - val req = Requiredness.apply(ns, usesAndDefs, ec, inputs) + val req = Requiredness.apply(ns, usesAndDefs, ctx, inputs) VirtualTypeWithReq(ir.typ, req.lookup(ns).asInstanceOf[TypeWithRequiredness]) } diff --git a/hail/src/test/scala/is/hail/expr/ir/ForwardLetsSuite.scala b/hail/src/test/scala/is/hail/expr/ir/ForwardLetsSuite.scala index c778465f0cd..35534bb6380 100644 --- a/hail/src/test/scala/is/hail/expr/ir/ForwardLetsSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/ForwardLetsSuite.scala @@ -85,45 +85,45 @@ class ForwardLetsSuite extends HailSuite { @Test(dataProvider = "nonForwardingOps") def testNonForwardingOps(ir: IR): Unit = { - val after = ForwardLets(ir) - val normalizedBefore = (new NormalizeNames(_.toString)).apply(ir) - val normalizedAfter = (new NormalizeNames(_.toString)).apply(after) + val after = ForwardLets(ctx)(ir) + val normalizedBefore = (new NormalizeNames(_.toString))(ctx, ir) + val normalizedAfter = (new NormalizeNames(_.toString))(ctx, after) assert(normalizedBefore == normalizedAfter) } @Test(dataProvider = "nonForwardingNonEvalOps") def testNonForwardingNonEvalOps(ir: IR): Unit = { - val after = ForwardLets(ir) + val after = ForwardLets(ctx)(ir) assert(after.isInstanceOf[Let]) } @Test(dataProvider = "nonForwardingAggOps") def testNonForwardingAggOps(ir: IR): Unit = { - val after = ForwardLets(ir) + val after = ForwardLets(ctx)(ir) assert(after.isInstanceOf[AggLet]) } @Test(dataProvider = "forwardingOps") def testForwardingOps(ir: IR): Unit = { - val after = ForwardLets(ir) + val after = ForwardLets(ctx)(ir) assert(!after.isInstanceOf[Let]) assertEvalSame(ir, args = Array(5 -> TInt32)) } @Test(dataProvider = "forwardingAggOps") def testForwardingAggOps(ir: IR): Unit = { - val after = ForwardLets(ir) + val after = ForwardLets(ctx)(ir) assert(!after.isInstanceOf[AggLet]) } @Test def testLetNoMention(): Unit = { val ir = Let("x", I32(1), I32(2)) - assert(ForwardLets[IR](ir) == I32(2)) + assert(ForwardLets[IR](ctx)(ir) == I32(2)) } @Test def testLetRefRewrite(): Unit = { val ir = Let("x", I32(1), Ref("x", TInt32)) - assert(ForwardLets[IR](ir) == I32(1)) + assert(ForwardLets[IR](ctx)(ir) == I32(1)) } @Test def testAggregators(): Unit = { @@ -133,10 +133,7 @@ class ForwardLetsSuite extends HailSuite { })) .apply(aggEnv) - TypeCheck( - ctx, - ForwardLets(ir0).asInstanceOf[IR], - BindingEnv(Env.empty, agg = Some(aggEnv))) + TypeCheck(ctx, ForwardLets(ctx)(ir0), BindingEnv(Env.empty, agg = Some(aggEnv))) } @Test def testNestedBindingOverwrites(): Unit = { @@ -146,7 +143,7 @@ class ForwardLetsSuite extends HailSuite { }(env) TypeCheck(ctx, ir, BindingEnv(env)) - TypeCheck(ctx, ForwardLets(ir).asInstanceOf[IR], BindingEnv(env)) + TypeCheck(ctx, ForwardLets(ctx)(ir), BindingEnv(env)) } @Test def testLetsDoNotForwardInsideArrayAggWithNoOps(): Unit = { @@ -163,6 +160,6 @@ class ForwardLetsSuite extends HailSuite { ))) TypeCheck(ctx, x, BindingEnv(Env("y" -> TInt32))) - TypeCheck(ctx, ForwardLets(x).asInstanceOf[IR], BindingEnv(Env("y" -> TInt32))) + TypeCheck(ctx, ForwardLets(ctx)(x), BindingEnv(Env("y" -> TInt32))) } } diff --git a/hail/src/test/scala/is/hail/expr/ir/IRSuite.scala b/hail/src/test/scala/is/hail/expr/ir/IRSuite.scala index f3d83d9507e..519e4f0258c 100644 --- a/hail/src/test/scala/is/hail/expr/ir/IRSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/IRSuite.scala @@ -3387,8 +3387,8 @@ class IRSuite extends HailSuite { @Test def testHasIRSharing(): Unit = { val r = Ref("x", TInt32) val ir1 = MakeTuple.ordered(FastSeq(I64(1), r, r, I32(1))) - assert(HasIRSharing(ir1)) - assert(!HasIRSharing(ir1.deepCopy())) + assert(HasIRSharing(ctx)(ir1)) + assert(!HasIRSharing(ctx)(ir1.deepCopy())) } @Test def freeVariablesAggScanBindingEnv(): Unit = { diff --git a/hail/src/test/scala/is/hail/expr/ir/SimplifySuite.scala b/hail/src/test/scala/is/hail/expr/ir/SimplifySuite.scala index 87aa463d6ba..b3a7c5c6e56 100644 --- a/hail/src/test/scala/is/hail/expr/ir/SimplifySuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/SimplifySuite.scala @@ -129,7 +129,7 @@ class SimplifySuite extends HailSuite { ) ) ) - val simplified = new NormalizeNames(_.toString, true).apply(Simplify(ctx, l)) + val simplified = new NormalizeNames(_.toString, true)(ctx, Simplify(ctx, l)) val expected = Let("1", I32(1) + Ref("OTHER_1", TInt32), Let("2", I32(1) + Ref("1", TInt32),