Skip to content

Commit

Permalink
[compiler] simplify PruneDeadFields (#14509)
Browse files Browse the repository at this point in the history
Makes some further progress on simplifying the `PruneDeadFields` pass,
with the primary goal of decoupling it from the details of the binding
structure.

The primary change is to `memoizeValueIR`. Before, it passed in only the
requested type of the node, and returned an environment containing all
free variables and their requested types. Any bound variables would then
need to be removed, and the environments of all children then merged.
This low-level manipulation of environments made it closely tied to the
binding structure, essentially redundantly encoding everything in
`Binds.scala`.

Now we pass an environment down into the children, which maps variables
to a mutable state tracking the requested type. Each `Ref` node unions
the requested type at the reference with the state in the environment.
This lets us use the general environment infrastructure.

I didn't add an assertion directly comparing the old and new
implementations, as I've done with some other pass rewrites. But
`PruneDeadFields` has pretty good test coverage.
  • Loading branch information
patrick-schultz committed May 17, 2024
1 parent e1e42c0 commit 9822e7d
Show file tree
Hide file tree
Showing 12 changed files with 745 additions and 906 deletions.
2 changes: 2 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/BaseIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ abstract class BaseIR {

def children: Iterable[BaseIR] = childrenSeq

def getChild(idx: Int): BaseIR = childrenSeq(idx)

protected def copyWithNewChildren(newChildren: IndexedSeq[BaseIR]): BaseIR

def deepCopy(): this.type =
Expand Down
2 changes: 1 addition & 1 deletion hail/src/main/scala/is/hail/expr/ir/Binds.scala
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ object Bindings {
case StreamFold(a, zero, accumName, valueName, _) if i == 2 =>
Bindings(FastSeq(accumName -> zero.typ, valueName -> elementType(a.typ)))
case StreamFold2(a, accum, valueName, _, _) =>
if (i <= accum.length)
if (i < accum.length + 1)
Bindings.empty
else if (i < 2 * accum.length + 1)
Bindings(
Expand Down
4 changes: 2 additions & 2 deletions hail/src/main/scala/is/hail/expr/ir/DeprecatedIRBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -220,10 +220,10 @@ object DeprecatedIRBuilder {

def castRename(t: Type): IRProxy = (env: E) => CastRename(ir(env), t)

def insertFields(fields: (Symbol, IRProxy)*): IRProxy = insertFieldsList(fields)
def insertFields(fields: (Symbol, IRProxy)*): IRProxy = insertFieldsList(fields.toFastSeq)

def insertFieldsList(
fields: Seq[(Symbol, IRProxy)],
fields: IndexedSeq[(Symbol, IRProxy)],
ordering: Option[IndexedSeq[String]] = None,
): IRProxy = (env: E) =>
InsertFields(ir(env), fields.map { case (s, fir) => (s.name, fir(env)) }, ordering)
Expand Down
5 changes: 3 additions & 2 deletions hail/src/main/scala/is/hail/expr/ir/IR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -955,12 +955,13 @@ final case class MakeStruct(fields: IndexedSeq[(String, IR)]) extends IR
final case class SelectFields(old: IR, fields: IndexedSeq[String]) extends IR

object InsertFields {
def apply(old: IR, fields: Seq[(String, IR)]): InsertFields = InsertFields(old, fields, None)
def apply(old: IR, fields: IndexedSeq[(String, IR)]): InsertFields =
InsertFields(old, fields, None)
}

final case class InsertFields(
old: IR,
fields: Seq[(String, IR)],
fields: IndexedSeq[(String, IR)],
fieldOrder: Option[IndexedSeq[String]],
) extends TypedIR[TStruct]

Expand Down
16 changes: 8 additions & 8 deletions hail/src/main/scala/is/hail/expr/ir/Pretty.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import is.hail.expr.JSONAnnotationImpex
import is.hail.expr.ir.Pretty.prettyBooleanLiteral
import is.hail.expr.ir.agg._
import is.hail.expr.ir.functions.RelationalFunctions
import is.hail.types.TableType
import is.hail.types.{MatrixType, TableType}
import is.hail.types.virtual.{TArray, TInterval, TStream, Type}
import is.hail.utils.{space => _, _}
import is.hail.utils.prettyPrint._
Expand Down Expand Up @@ -851,7 +851,7 @@ class Pretty(
case I32(i) => s"c$i"
case stream if stream.typ.isInstanceOf[TStream] => "s"
case table if table.typ.isInstanceOf[TableType] => "ht"
case mt if mt.typ.isInstanceOf[TableType] => "mt"
case mt if mt.typ.isInstanceOf[MatrixType] => "mt"
case _ => ""
}

Expand Down Expand Up @@ -904,12 +904,12 @@ class Pretty(
ir match {
case Ref(name, _) =>
val body =
blockBindings.lookupOption(name).getOrElse(uniqueify("%undefined_ref", Some(name)))
blockBindings.lookupOption(name).getOrElse("#" + uniqueify("undefined_ref", Some(name)))
concat(openBlock, group(nest(2, concat(line, body, line)), "}"))
case RelationalRef(name, _) =>
val body =
blockBindings.lookupOption(name).getOrElse(uniqueify(
"%undefined_relational_ref",
blockBindings.lookupOption(name).getOrElse("#" + uniqueify(
"undefined_relational_ref",
Some(name),
))
concat(openBlock, group(nest(2, concat(line, body, line)), "}"))
Expand Down Expand Up @@ -953,10 +953,10 @@ class Pretty(
} yield {
child match {
case Ref(name, _) =>
bindings.lookupOption(name).getOrElse(uniqueify("%undefined_ref", Some(name)))
bindings.lookupOption(name).getOrElse("#" + uniqueify("undefined_ref", Some(name)))
case RelationalRef(name, _) =>
bindings.lookupOption(name).getOrElse(uniqueify(
"%undefined_relational_ref",
bindings.lookupOption(name).getOrElse("#" + uniqueify(
"undefined_relational_ref",
Some(name),
))
case _ =>
Expand Down

0 comments on commit 9822e7d

Please sign in to comment.