Skip to content

Commit

Permalink
simplify NormalizeNames
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-schultz committed May 14, 2024
1 parent 1e6bcab commit c8b6f3b
Show file tree
Hide file tree
Showing 15 changed files with 527 additions and 232 deletions.
49 changes: 38 additions & 11 deletions hail/src/main/scala/is/hail/expr/ir/BaseIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ abstract class BaseIR {

def children: Iterable[BaseIR] = childrenSeq

protected def copy(newChildren: IndexedSeq[BaseIR]): BaseIR
protected def copyWithNewChildren(newChildren: IndexedSeq[BaseIR]): BaseIR

def deepCopy(): this.type =
copy(newChildren = childrenSeq.map(_.deepCopy())).asInstanceOf[this.type]
copyWithNewChildren(newChildren = childrenSeq.map(_.deepCopy())).asInstanceOf[this.type]

def noSharing(ctx: ExecuteContext): this.type =
if (HasIRSharing(ctx)(this)) this.deepCopy() else this
Expand All @@ -33,23 +33,23 @@ abstract class BaseIR {
if (childrenSeq.elementsSameObjects(newChildren))
this
else
copy(newChildren)
copyWithNewChildren(newChildren)
}

def mapChildren(f: (BaseIR) => BaseIR): BaseIR = {
val newChildren = childrenSeq.map(f)
if (childrenSeq.elementsSameObjects(newChildren))
this
else
copy(newChildren)
copyWithNewChildren(newChildren)
}

def mapChildrenWithIndexStackSafe(f: (BaseIR, Int) => StackFrame[BaseIR]): StackFrame[BaseIR] = {
call(childrenSeq.iterator.zipWithIndex.map(f.tupled).collectRecur).map { newChildren =>
if (childrenSeq.elementsSameObjects(newChildren))
this
else
copy(newChildren)
copyWithNewChildren(newChildren)
}
}

Expand All @@ -58,7 +58,7 @@ abstract class BaseIR {
if (childrenSeq.elementsSameObjects(newChildren))
this
else
copy(newChildren)
copyWithNewChildren(newChildren)
}
}

Expand All @@ -68,22 +68,49 @@ abstract class BaseIR {
f(child, childEnv)
}

def mapChildrenWithEnv[E <: GenericBindingEnv[E, Type]](env: E)(f: (BaseIR, E) => BaseIR)
: BaseIR = {
val newChildren = childrenSeq.toArray
def mapChildrenWithEnv(env: BindingEnv[Type])(f: (BaseIR, BindingEnv[Type]) => BaseIR): BaseIR =
mapChildrenWithEnv[BindingEnv[Type]](env, (env, bindings) => env.extend(bindings))(f)

def mapChildrenWithEnv[E](
env: E,
update: (E, Bindings[Type]) => E,
)(
f: (BaseIR, E) => BaseIR
): BaseIR = {
val newChildren = Array(childrenSeq: _*)
var res = this
for (i <- newChildren.indices) {
val childEnv = env.extend(Bindings.get(res, i))
val childEnv = update(env, Bindings.get(res, i))
val child = newChildren(i)
val newChild = f(child, childEnv)
if (!(newChild eq child)) {
newChildren(i) = newChild
res = res.copy(newChildren)
res = res.copyWithNewChildren(newChildren)
}
}
res
}

def mapChildrenWithEnvStackSafe[E](
env: E,
update: (E, Bindings[Type]) => E,
)(
f: (BaseIR, E) => StackFrame[BaseIR]
): StackFrame[BaseIR] = {
val newChildren = Array(childrenSeq: _*)
var res = this
newChildren.indices.foreachRecur { i =>
val childEnv = update(env, Bindings.get(res, i))
val child = newChildren(i)
f(child, childEnv).map { newChild =>
if (!(newChild eq child)) {
newChildren(i) = newChild
res = res.copyWithNewChildren(newChildren)
}
}
}.map(_ => res)
}

def forEachChildWithEnvStackSafe[E <: GenericBindingEnv[E, Type]](
env: E
)(
Expand Down
29 changes: 11 additions & 18 deletions hail/src/main/scala/is/hail/expr/ir/Binds.scala
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,11 @@ object Bindings {
Bindings(FastSeq(name -> elementType(stream.typ)))
case RunAggScan(a, name, _, _, _, _) if i == 2 || i == 3 =>
Bindings(FastSeq(name -> elementType(a.typ)))
case StreamAgg(a, name, _) if i == 1 =>
Bindings(
FastSeq(name -> elementType(a.typ)),
agg = AggEnv.Create(FastSeq(0)),
)
case StreamScan(a, zero, accumName, valueName, _) if i == 2 =>
Bindings(FastSeq(accumName -> zero.typ, valueName -> elementType(a.typ)))
case StreamAggScan(a, name, _) if i == 1 =>
Expand Down Expand Up @@ -382,33 +387,26 @@ object Bindings {
rName -> tcoerce[TNDArray](r.typ).elementType,
))
case CollectDistributedArray(contexts, globals, cname, gname, _, _, _, _) if i == 2 =>
Bindings(
Bindings.inFreshScope(
FastSeq(
cname -> elementType(contexts.typ),
gname -> globals.typ,
),
agg = AggEnv.Drop,
scan = AggEnv.Drop,
dropEval = true,
)
)
case TableAggregate(child, _) =>
if (i == 1)
Bindings(
Bindings.inFreshScope(
child.typ.rowBindings,
eval = TableType.globalBindings,
agg = AggEnv.Create(TableType.rowBindings),
scan = AggEnv.Drop,
dropEval = true,
agg = Some(TableType.rowBindings),
)
else Bindings(agg = AggEnv.Drop, scan = AggEnv.Drop, dropEval = true)
case MatrixAggregate(child, _) =>
if (i == 1)
Bindings(
Bindings.inFreshScope(
child.typ.entryBindings,
eval = MatrixType.globalBindings,
agg = AggEnv.Create(MatrixType.entryBindings),
scan = AggEnv.Drop,
dropEval = true,
agg = Some(MatrixType.entryBindings),
)
else Bindings(agg = AggEnv.Drop, scan = AggEnv.Drop, dropEval = true)
case ApplyAggOp(init, _, _) =>
Expand Down Expand Up @@ -439,11 +437,6 @@ object Bindings {
agg = if (isScan) AggEnv.NoOp else AggEnv.Bind(FastSeq(0)),
scan = if (!isScan) AggEnv.NoOp else AggEnv.Bind(FastSeq(0)),
)
case StreamAgg(a, name, _) if i == 1 =>
Bindings(
FastSeq(name -> elementType(a.typ)),
agg = AggEnv.Create(FastSeq(0)),
)
case RelationalLet(name, value, _) =>
if (i == 1)
Bindings(
Expand Down
30 changes: 16 additions & 14 deletions hail/src/main/scala/is/hail/expr/ir/BlockMatrixIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ sealed abstract class BlockMatrixIR extends BaseIR {
protected[ir] def execute(ctx: ExecuteContext): BlockMatrix =
fatal("tried to execute unexecutable IR:\n" + Pretty(ctx, this))

def copy(newChildren: IndexedSeq[BaseIR]): BlockMatrixIR
override protected def copyWithNewChildren(newChildren: IndexedSeq[BaseIR]): BlockMatrixIR

def blockCostIsLinear: Boolean

Expand All @@ -82,7 +82,7 @@ case class BlockMatrixRead(reader: BlockMatrixReader) extends BlockMatrixIR {

lazy val childrenSeq: IndexedSeq[BaseIR] = Array.empty[BlockMatrixIR]

def copy(newChildren: IndexedSeq[BaseIR]): BlockMatrixRead = {
override protected def copyWithNewChildren(newChildren: IndexedSeq[BaseIR]): BlockMatrixRead = {
assert(newChildren.isEmpty)
BlockMatrixRead(reader)
}
Expand Down Expand Up @@ -302,7 +302,7 @@ case class BlockMatrixMap(child: BlockMatrixIR, eltName: String, f: IR, needsDen

lazy val childrenSeq: IndexedSeq[BaseIR] = Array(child, f)

def copy(newChildren: IndexedSeq[BaseIR]): BlockMatrixMap = {
override protected def copyWithNewChildren(newChildren: IndexedSeq[BaseIR]): BlockMatrixMap = {
val IndexedSeq(newChild: BlockMatrixIR, newF: IR) = newChildren
BlockMatrixMap(newChild, eltName, newF, needsDense)
}
Expand Down Expand Up @@ -450,7 +450,7 @@ case class BlockMatrixMap2(

val blockCostIsLinear: Boolean = left.blockCostIsLinear && right.blockCostIsLinear

def copy(newChildren: IndexedSeq[BaseIR]): BlockMatrixMap2 = {
override protected def copyWithNewChildren(newChildren: IndexedSeq[BaseIR]): BlockMatrixMap2 = {
assert(newChildren.length == 3)
BlockMatrixMap2(
newChildren(0).asInstanceOf[BlockMatrixIR],
Expand Down Expand Up @@ -581,7 +581,7 @@ case class BlockMatrixDot(left: BlockMatrixIR, right: BlockMatrixIR) extends Blo

lazy val childrenSeq: IndexedSeq[BaseIR] = Array(left, right)

def copy(newChildren: IndexedSeq[BaseIR]): BlockMatrixDot = {
override protected def copyWithNewChildren(newChildren: IndexedSeq[BaseIR]): BlockMatrixDot = {
assert(newChildren.length == 2)
BlockMatrixDot(
newChildren(0).asInstanceOf[BlockMatrixIR],
Expand Down Expand Up @@ -680,7 +680,8 @@ case class BlockMatrixBroadcast(

lazy val childrenSeq: IndexedSeq[BaseIR] = Array(child)

def copy(newChildren: IndexedSeq[BaseIR]): BlockMatrixBroadcast = {
override protected def copyWithNewChildren(newChildren: IndexedSeq[BaseIR])
: BlockMatrixBroadcast = {
assert(newChildren.length == 1)
BlockMatrixBroadcast(newChildren(0).asInstanceOf[BlockMatrixIR], inIndexExpr, shape, blockSize)
}
Expand Down Expand Up @@ -760,7 +761,7 @@ case class BlockMatrixAgg(

lazy val childrenSeq: IndexedSeq[BaseIR] = Array(child)

def copy(newChildren: IndexedSeq[BaseIR]): BlockMatrixAgg = {
override protected def copyWithNewChildren(newChildren: IndexedSeq[BaseIR]): BlockMatrixAgg = {
assert(newChildren.length == 1)
BlockMatrixAgg(newChildren(0).asInstanceOf[BlockMatrixIR], axesToSumOut)
}
Expand Down Expand Up @@ -815,7 +816,7 @@ case class BlockMatrixFilter(

override def childrenSeq: IndexedSeq[BaseIR] = Array(child)

def copy(newChildren: IndexedSeq[BaseIR]): BlockMatrixFilter = {
override protected def copyWithNewChildren(newChildren: IndexedSeq[BaseIR]): BlockMatrixFilter = {
assert(newChildren.length == 1)
BlockMatrixFilter(newChildren(0).asInstanceOf[BlockMatrixIR], indices)
}
Expand Down Expand Up @@ -844,7 +845,7 @@ case class BlockMatrixDensify(child: BlockMatrixIR) extends BlockMatrixIR {

val childrenSeq: IndexedSeq[BaseIR] = FastSeq(child)

def copy(newChildren: IndexedSeq[BaseIR]): BlockMatrixIR = {
override protected def copyWithNewChildren(newChildren: IndexedSeq[BaseIR]): BlockMatrixIR = {
val IndexedSeq(newChild: BlockMatrixIR) = newChildren
BlockMatrixDensify(newChild)
}
Expand Down Expand Up @@ -965,7 +966,7 @@ case class BlockMatrixSparsify(

val childrenSeq: IndexedSeq[BaseIR] = Array(child)

def copy(newChildren: IndexedSeq[BaseIR]): BlockMatrixIR = {
override protected def copyWithNewChildren(newChildren: IndexedSeq[BaseIR]): BlockMatrixIR = {
val IndexedSeq(newChild: BlockMatrixIR) = newChildren
BlockMatrixSparsify(newChild, sparsifier)
}
Expand Down Expand Up @@ -1010,7 +1011,7 @@ case class BlockMatrixSlice(child: BlockMatrixIR, slices: IndexedSeq[IndexedSeq[

override def childrenSeq: IndexedSeq[BaseIR] = Array(child)

override def copy(newChildren: IndexedSeq[BaseIR]): BlockMatrixIR = {
override protected def copyWithNewChildren(newChildren: IndexedSeq[BaseIR]): BlockMatrixIR = {
assert(newChildren.length == 1)
BlockMatrixSlice(newChildren(0).asInstanceOf[BlockMatrixIR], slices)
}
Expand Down Expand Up @@ -1062,7 +1063,8 @@ case class ValueToBlockMatrix(

lazy val childrenSeq: IndexedSeq[BaseIR] = Array(child)

def copy(newChildren: IndexedSeq[BaseIR]): ValueToBlockMatrix = {
override protected def copyWithNewChildren(newChildren: IndexedSeq[BaseIR])
: ValueToBlockMatrix = {
assert(newChildren.length == 1)
ValueToBlockMatrix(newChildren(0).asInstanceOf[IR], shape, blockSize)
}
Expand Down Expand Up @@ -1108,7 +1110,7 @@ case class BlockMatrixRandom(

lazy val childrenSeq: IndexedSeq[BaseIR] = Array.empty[BaseIR]

def copy(newChildren: IndexedSeq[BaseIR]): BlockMatrixRandom = {
override protected def copyWithNewChildren(newChildren: IndexedSeq[BaseIR]): BlockMatrixRandom = {
assert(newChildren.isEmpty)
BlockMatrixRandom(staticUID, gaussian, shape, blockSize)
}
Expand All @@ -1125,7 +1127,7 @@ case class RelationalLetBlockMatrix(name: String, value: IR, body: BlockMatrixIR

val blockCostIsLinear: Boolean = body.blockCostIsLinear

def copy(newChildren: IndexedSeq[BaseIR]): BlockMatrixIR = {
override protected def copyWithNewChildren(newChildren: IndexedSeq[BaseIR]): BlockMatrixIR = {
val IndexedSeq(newValue: IR, newBody: BlockMatrixIR) = newChildren
RelationalLetBlockMatrix(name, newValue, newBody)
}
Expand Down
12 changes: 10 additions & 2 deletions hail/src/main/scala/is/hail/expr/ir/Compile.scala
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,11 @@ object Compile {
): (Option[SingleCodeType], (HailClassLoader, FS, HailTaskContext, Region) => F) = {

val normalizedBody =
new NormalizeNames(_.toString)(ctx, body, Env(params.map { case (n, _) => n -> n }: _*))
NormalizeNames(
ctx,
body,
BindingEnv(Env(params.map { case (n, _) => n -> n }: _*)),
)
val k =
CodeCacheKey(FastSeq[AggStateSig](), params.map { case (n, pt) => (n, pt) }, normalizedBody)
(ctx.backend.lookupOrCompileCachedFunction[F](k) {
Expand Down Expand Up @@ -107,7 +111,11 @@ object CompileWithAggregators {
(HailClassLoader, FS, HailTaskContext, Region) => (F with FunctionWithAggRegion),
) = {
val normalizedBody =
new NormalizeNames(_.toString)(ctx, body, Env(params.map { case (n, _) => n -> n }: _*))
NormalizeNames(
ctx,
body,
BindingEnv(Env(params.map { case (n, _) => n -> n }: _*)),
)
val k = CodeCacheKey(aggSigs, params.map { case (n, pt) => (n, pt) }, normalizedBody)
(ctx.backend.lookupOrCompileCachedFunction[F with FunctionWithAggRegion](k) {

Expand Down
2 changes: 1 addition & 1 deletion hail/src/main/scala/is/hail/expr/ir/ForwardLets.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import scala.collection.Set

object ForwardLets {
def apply[T <: BaseIR](ctx: ExecuteContext)(ir0: T): T = {
val ir1 = new NormalizeNames(_ => genUID(), allowFreeVariables = true)(ctx, ir0)
val ir1 = NormalizeNames(ctx, ir0, allowFreeVariables = true)
val UsesAndDefs(uses, defs, _) = ComputeUsesAndDefs(ir1, errorIfFreeVariables = false)
val nestingDepth = NestingDepth(ir1)

Expand Down
2 changes: 1 addition & 1 deletion hail/src/main/scala/is/hail/expr/ir/IR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ sealed trait IR extends BaseIR {
protected lazy val childrenSeq: IndexedSeq[BaseIR] =
Children(this)

override protected def copy(newChildren: IndexedSeq[BaseIR]): IR =
override protected def copyWithNewChildren(newChildren: IndexedSeq[BaseIR]): IR =
Copy(this, newChildren)

override def mapChildren(f: BaseIR => BaseIR): IR = super.mapChildren(f).asInstanceOf[IR]
Expand Down

0 comments on commit c8b6f3b

Please sign in to comment.