Skip to content

Commit

Permalink
improve 1164 (#1166)
Browse files Browse the repository at this point in the history
  • Loading branch information
johnynek committed Mar 10, 2024
1 parent 2b77f45 commit f813e1d
Show file tree
Hide file tree
Showing 8 changed files with 361 additions and 41 deletions.
59 changes: 56 additions & 3 deletions core/src/main/scala/org/bykn/bosatsu/TypedExpr.scala
Expand Up @@ -974,9 +974,7 @@ object TypedExpr {
}

private def allPatternTypes[N](p: Pattern[N, Type]): SortedSet[Type] =
p.traverseType(t => Writer[SortedSet[Type], Type](SortedSet(t), t))
.run
._1
p.traverseType(t => Writer[SortedSet[Type], Type](SortedSet(t), t)).run._1

// Invariant, nel must have at least one item in common with quant.vars
private def filterQuant(
Expand Down Expand Up @@ -1390,6 +1388,16 @@ object TypedExpr {
if Type.quantify(q, tpe).sameAs(term.getType) =>
// we not uncommonly add an annotation just to make a generic wrapper to get back where
term
case Annotation(term, tpe)
if !q.vars.iterator
.map(_._1)
.exists(
Type.freeBoundTyVars(expr.getType :: Nil).toSet
) =>
// the variables may be free lower, but not here
val genTerm = normalizeQuantVars(q, term)
if (genTerm.getType.sameAs(tpe)) genTerm
else Annotation(normalizeQuantVars(q, term), tpe)
case _ =>
import Type.Quantification._
// We cannot rebind to any used typed inside of expr, but we can reuse
Expand Down Expand Up @@ -1478,4 +1486,49 @@ object TypedExpr {

implicit def typedExprHasRegion[T: HasRegion]: HasRegion[TypedExpr[T]] =
HasRegion.instance[TypedExpr[T]](e => HasRegion.region(e.tag))

// noshadow must include any free vars of args
def liftQuantification[A](
args: NonEmptyList[TypedExpr[A]],
noshadow: Set[Type.Var.Bound]
): (
Option[Type.Quantification],
NonEmptyList[TypedExpr[A]]
) = {

val htype = args.head.getType
val (oq, rest) = NonEmptyList.fromList(args.tail) match {
case Some(neTail) =>
val (oq, rest) = liftQuantification(neTail, noshadow)
(oq, rest.toList)
case None =>
(None, Nil)
}

htype match {
case Type.Quantified(q, rho) =>
oq match {
case Some(qtail) =>
// we have to unshadow with noshadow + all the vars in the tail
val (map, q1) =
q.unshadow(noshadow ++ qtail.vars.toList.iterator.map(_._1))
val rho1 = Type.substituteRhoVar(rho, map)
(
Some(q1.concat(qtail)),
NonEmptyList(TypedExpr.Annotation(args.head, rho1), rest)
)
case None =>
val (map, q1) = q.unshadow(noshadow)
val rho1 = Type.substituteRhoVar(rho, map)
(
Some(q1),
NonEmptyList(TypedExpr.Annotation(args.head, rho1), rest)
)
}

case _ =>
(oq, NonEmptyList(args.head, rest))
}
}

}
Expand Up @@ -327,6 +327,9 @@ object TypedExprNormalization {
}

f1 match {
// TODO: what if f1: Generic(_, AnnotatedLambda(_, _, _))
// we should still be able ton convert this to a let by
// instantiating to the right args
case AnnotatedLambda(lamArgs, expr, _) =>
// (y -> z)(x) = let y = x in z
val lets = lamArgs.zip(args).map { case ((n, ltpe), arg) =>
Expand Down
183 changes: 156 additions & 27 deletions core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala
Expand Up @@ -428,9 +428,7 @@ object Infer {

private val checkedKinds: Infer[Type => Option[Kind]] = {
val emptyRegion = Region(0, 0)
GetEnv.map { env =>
tpe => env.getKind(tpe, emptyRegion).toOption
}
GetEnv.map(env => tpe => env.getKind(tpe, emptyRegion).toOption)
}

// on t[a] we know t: k -> *, what is the variance
Expand Down Expand Up @@ -561,10 +559,10 @@ object Infer {
* with what they point to
*/
def zonkType(t: Type): Infer[Type] =
Type.zonkMeta(t)(zonk(_))
Type.zonkMeta(t)(zonk)

def zonkTypedExpr[A](e: TypedExpr[A]): Infer[TypedExpr[A]] =
TypedExpr.zonkMeta(e)(zonk(_))
TypedExpr.zonkMeta(e)(zonk)

val zonkTypeExprK
: FunctionK[TypedExpr.Rho, Lambda[x => Infer[TypedExpr[x]]]] =
Expand Down Expand Up @@ -1176,9 +1174,9 @@ object Infer {
left: Region,
right: Region
): Option[Infer[TypedExpr.Coerce]] =
inferred match {
(inferred match {
case Type.ForAll(vars, inT) =>
Type.instantiate(vars.iterator.toMap, inT, declared).map {
Type.instantiate(vars.iterator.toMap, inT, declared, Map.empty).map {
case (_, subs) =>
validateSubs(subs.toList, left, right)
.as {
Expand All @@ -1191,8 +1189,30 @@ object Infer {
}
}
}
case _ => None
}
case _ =>
None
}).orElse(declared match {
case Type.Exists(vars, inT) =>
Type.instantiate(vars.iterator.toMap, inT, inferred, Map.empty).map {
case (_, subs) =>
validateSubs(subs.toList, left, right)
.as {
new FunctionK[TypedExpr, TypedExpr] {
def apply[A](te: TypedExpr[A]): TypedExpr[A] =
// we apply the annotation here and let Normalization
// instantiate. We could explicitly have
// instantiation TypedExpr where you pass the variables to set
TypedExpr.Annotation(te, declared)
}
}
}
case _ =>
// TODO: we should be able to handle Dual quantification which could
// solve more cases. The challenge is existentials and universals appear
// on different sides, so cases where both need solutions can't be done
// with the current method that only solves one direction now.
None
})
// note, this is identical to subsCheckRho when declared is a Rho type
def subsCheck(
inferred: Type,
Expand Down Expand Up @@ -1377,8 +1397,9 @@ object Infer {
if inT.length == args.length =>
// see if we can instantiate the result type
// if we can, we use that to fix the known parameters and continue
Type.instantiate(univ.iterator.toMap, outT, tpe).flatMap {
case (frees, inst) =>
Type
.instantiate(univ.iterator.toMap, outT, tpe, Map.empty)
.flatMap { case (frees, inst) =>
// if instantiate works, we know outT => tpe
if (inst.nonEmpty && frees.isEmpty) {
// we made some progress and there are no frees
Expand All @@ -1389,7 +1410,7 @@ object Infer {
// We learned nothing
None
}
}
}
case _ =>
None
}
Expand Down Expand Up @@ -1462,6 +1483,86 @@ object Infer {
}
}

def applyViaInst[A: HasRegion](
fn: Expr[A],
args: NonEmptyList[Expr[A]],
tag: A
): Infer[Option[TypedExpr[A]]] =
(maybeSimple(fn), args.traverse(maybeSimple(_))).mapN {
(infFn, infArgs) =>
infFn.flatMap { fnTe =>
fnTe.getType match {
case Type.Fun.SimpleUniversal(us, argsT, resT)
if argsT.length == args.length =>
infArgs.sequence
.flatMap { argsTE =>
val argTypes = argsTE.map(_.getType)
// we can lift any quantification of the args
// outside of the function application
// We have to lift *before* substitution
val noshadows =
Type.freeBoundTyVars(resT :: argTypes.toList).toSet ++
us.iterator.map(_._1)
val (optQ, liftArgs) =
TypedExpr.liftQuantification(argsTE, noshadows)

val liftArgTypes = liftArgs.map(_.getType)
Type.instantiate(
us.toList.toMap,
Type.Tuple(argsT.toList),
Type.Tuple(liftArgTypes.toList),
optQ.fold(Map.empty[Type.Var.Bound, Kind])(
_.vars.toList.toMap
)
) match {
case None =>
/*
println(s"can't instantiate: ${
Type.fullyResolvedDocument.document(fnTe.getType).render(80)
} to ${liftArgTypes.map(Type.fullyResolvedDocument.document(_).render(80))}")
*/
pureNone
case Some((frees, inst)) =>
if (frees.nonEmpty) {
// TODO maybe we could handle this, but not yet
// seems like if the free vars are set to the same
// variable, then we can just lift it into the
// quantification
/*
println(s"remaining frees in ${
Type.fullyResolvedDocument.document(fnTe.getType).render(80)
} to ${liftArgTypes.map(Type.fullyResolvedDocument.document(_).render(80))}: $frees")
*/
pureNone
} else {
val subMap =
inst.view.mapValues(_._2).toMap[Type.Var, Type]
val fnType0 = Type.Fun(liftArgTypes, resT)
val fnType1 = Type.substituteVar(fnType0, subMap)
val resType = Type.substituteVar(resT, subMap)

val resTe = TypedExpr.App(
TypedExpr.Annotation(fnTe, fnType1),
liftArgs,
resType,
tag
)

val maybeQuant = optQ match {
case Some(q) => TypedExpr.Generic(q, resTe)
case None => resTe
}

pure(Some(maybeQuant))
}
}
}
case _ =>
pureNone
}
}
}.flatSequence

def applyRhoExpect[A: HasRegion](
fn: Expr[A],
args: NonEmptyList[Expr[A]],
Expand Down Expand Up @@ -1518,20 +1619,44 @@ object Infer {
res <- zonkTypedExpr(TypedExpr.Global(pack, name, vSigma, tag))
} yield coerce(res)
case Annotation(App(fn, args, tag), resT, annTag) =>
(
checkApply(fn, args, tag, resT, region(annTag)),
instSigma(resT, expect, region(annTag))
)
.parFlatMapN { (typedTerm, coerce) =>
zonkTypedExpr(typedTerm).map(coerce(_))
applyViaInst(fn, args, tag)
.flatMap {
case Some(te) =>
for {
co1 <- subsCheck(
te.getType,
resT,
region(tag),
region(annTag)
)
co2 <- instSigma(resT, expect, region(annTag))
z <- zonkTypedExpr(te)
} yield co2(co1(z))
case None =>
(
checkApply(fn, args, tag, resT, region(annTag)),
instSigma(resT, expect, region(annTag))
)
.parFlatMapN { (typedTerm, coerce) =>
zonkTypedExpr(typedTerm).map(coerce(_))
}
}
case App(fn, args, tag) =>
expect match {
case Expected.Check((rho, reg)) =>
checkApply(fn, args, tag, rho, reg)
case inf =>
applyRhoExpect(fn, args, tag, inf)
}
applyViaInst(fn, args, tag)
.flatMap {
case Some(te) =>
for {
co <- instSigma(te.getType, expect, HasRegion.region(tag))
z <- zonkTypedExpr(te)
} yield co(z)
case None =>
expect match {
case Expected.Check((rho, reg)) =>
checkApply(fn, args, tag, rho, reg)
case inf @ Expected.Inf(_) =>
applyRhoExpect(fn, args, tag, inf)
}
}
case Generic(tpes, in) =>
for {
unSkol <- inferForAll(tpes, in)
Expand Down Expand Up @@ -1697,7 +1822,12 @@ object Infer {
simp.flatMap { te =>
te.getType match {
case Type.ForAll(fas, in) =>
Type.instantiate(fas.iterator.toMap, in, rho) match {
Type.instantiate(
fas.iterator.toMap,
in,
rho,
Map.empty
) match {
case Some((frees, subs)) if frees.isEmpty =>
// we know that substituting in gives rho
// check kinds
Expand Down Expand Up @@ -2249,8 +2379,7 @@ object Infer {

for {
_ <- init
rhoT <- inferRho(e)
(rho, expTyRho) = rhoT
(rho, expTyRho) <- inferRho(e)
q <- quantify(unifySelf(expTyRho), rho)
} yield q
}
Expand Down

0 comments on commit f813e1d

Please sign in to comment.