Skip to content

Commit

Permalink
simplify PruneDeadFields
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-schultz committed Apr 26, 2024
1 parent e84931a commit 578e1c7
Show file tree
Hide file tree
Showing 12 changed files with 721 additions and 842 deletions.
2 changes: 2 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/BaseIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ abstract class BaseIR {

def children: Iterable[BaseIR] = childrenSeq

def getChild(idx: Int): BaseIR = childrenSeq(idx)

protected def withNewChildren(newChildren: IndexedSeq[BaseIR]): BaseIR

def deepCopy(): this.type =
Expand Down
2 changes: 1 addition & 1 deletion hail/src/main/scala/is/hail/expr/ir/Binds.scala
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ object Bindings {
case StreamFold(a, zero, accumName, valueName, _) if i == 2 =>
Bindings(FastSeq(accumName -> zero.typ, valueName -> elementType(a.typ)))
case StreamFold2(a, accum, valueName, _, _) =>
if (i <= accum.length)
if (i < accum.length + 1)
Bindings.empty
else if (i < 2 * accum.length + 1)
Bindings(
Expand Down
4 changes: 2 additions & 2 deletions hail/src/main/scala/is/hail/expr/ir/DeprecatedIRBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -220,10 +220,10 @@ object DeprecatedIRBuilder {

def castRename(t: Type): IRProxy = (env: E) => CastRename(ir(env), t)

def insertFields(fields: (Symbol, IRProxy)*): IRProxy = insertFieldsList(fields)
def insertFields(fields: (Symbol, IRProxy)*): IRProxy = insertFieldsList(fields.toFastSeq)

def insertFieldsList(
fields: Seq[(Symbol, IRProxy)],
fields: IndexedSeq[(Symbol, IRProxy)],
ordering: Option[IndexedSeq[String]] = None,
): IRProxy = (env: E) =>
InsertFields(ir(env), fields.map { case (s, fir) => (s.name, fir(env)) }, ordering)
Expand Down
5 changes: 3 additions & 2 deletions hail/src/main/scala/is/hail/expr/ir/IR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -959,12 +959,13 @@ final case class MakeStruct(fields: IndexedSeq[(String, IR)]) extends IR
final case class SelectFields(old: IR, fields: IndexedSeq[String]) extends IR

object InsertFields {
def apply(old: IR, fields: Seq[(String, IR)]): InsertFields = InsertFields(old, fields, None)
def apply(old: IR, fields: IndexedSeq[(String, IR)]): InsertFields =
InsertFields(old, fields, None)
}

final case class InsertFields(
old: IR,
fields: Seq[(String, IR)],
fields: IndexedSeq[(String, IR)],
fieldOrder: Option[IndexedSeq[String]],
) extends TypedIR[TStruct]

Expand Down
16 changes: 8 additions & 8 deletions hail/src/main/scala/is/hail/expr/ir/Pretty.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import is.hail.expr.JSONAnnotationImpex
import is.hail.expr.ir.Pretty.prettyBooleanLiteral
import is.hail.expr.ir.agg._
import is.hail.expr.ir.functions.RelationalFunctions
import is.hail.types.TableType
import is.hail.types.{MatrixType, TableType}
import is.hail.types.virtual.{TArray, TInterval, TStream, Type}
import is.hail.utils.{space => _, _}
import is.hail.utils.prettyPrint._
Expand Down Expand Up @@ -821,7 +821,7 @@ class Pretty(
case I32(i) => s"c$i"
case stream if stream.typ.isInstanceOf[TStream] => "s"
case table if table.typ.isInstanceOf[TableType] => "ht"
case mt if mt.typ.isInstanceOf[TableType] => "mt"
case mt if mt.typ.isInstanceOf[MatrixType] => "mt"
case _ => ""
}

Expand Down Expand Up @@ -874,12 +874,12 @@ class Pretty(
ir match {
case Ref(name, _) =>
val body =
blockBindings.lookupOption(name).getOrElse(uniqueify("%undefined_ref", Some(name)))
blockBindings.lookupOption(name).getOrElse("#" + uniqueify("undefined_ref", Some(name)))
concat(openBlock, group(nest(2, concat(line, body, line)), "}"))
case RelationalRef(name, _) =>
val body =
blockBindings.lookupOption(name).getOrElse(uniqueify(
"%undefined_relational_ref",
blockBindings.lookupOption(name).getOrElse("#" + uniqueify(
"undefined_relational_ref",
Some(name),
))
concat(openBlock, group(nest(2, concat(line, body, line)), "}"))
Expand Down Expand Up @@ -923,10 +923,10 @@ class Pretty(
} yield {
child match {
case Ref(name, _) =>
bindings.lookupOption(name).getOrElse(uniqueify("%undefined_ref", Some(name)))
bindings.lookupOption(name).getOrElse("#" + uniqueify("undefined_ref", Some(name)))
case RelationalRef(name, _) =>
bindings.lookupOption(name).getOrElse(uniqueify(
"%undefined_relational_ref",
bindings.lookupOption(name).getOrElse("#" + uniqueify(
"undefined_relational_ref",
Some(name),
))
case _ =>
Expand Down

0 comments on commit 578e1c7

Please sign in to comment.