From 10b6baa6ea1af241308132437371a6f4258f4acf Mon Sep 17 00:00:00 2001 From: Oscar Boykin Date: Thu, 7 Mar 2024 22:16:14 -0700 Subject: [PATCH 1/5] try type instantitation for tuple-like cases --- .../scala/org/bykn/bosatsu/rankn/Infer.scala | 64 +++++++++++++++++-- .../bykn/bosatsu/rankn/RankNInferTest.scala | 28 +++++++- test_workspace/TypeConstraint.bosatsu | 2 +- 3 files changed, 85 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala b/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala index 19e59fd4e..9eeda447b 100644 --- a/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala +++ b/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala @@ -428,9 +428,7 @@ object Infer { private val checkedKinds: Infer[Type => Option[Kind]] = { val emptyRegion = Region(0, 0) - GetEnv.map { env => - tpe => env.getKind(tpe, emptyRegion).toOption - } + GetEnv.map(env => tpe => env.getKind(tpe, emptyRegion).toOption) } // on t[a] we know t: k -> *, what is the variance @@ -561,10 +559,10 @@ object Infer { * with what they point to */ def zonkType(t: Type): Infer[Type] = - Type.zonkMeta(t)(zonk(_)) + Type.zonkMeta(t)(zonk) def zonkTypedExpr[A](e: TypedExpr[A]): Infer[TypedExpr[A]] = - TypedExpr.zonkMeta(e)(zonk(_)) + TypedExpr.zonkMeta(e)(zonk) val zonkTypeExprK : FunctionK[TypedExpr.Rho, Lambda[x => Infer[TypedExpr[x]]]] = @@ -1529,8 +1527,60 @@ object Infer { expect match { case Expected.Check((rho, reg)) => checkApply(fn, args, tag, rho, reg) - case inf => - applyRhoExpect(fn, args, tag, inf) + case inf @ Expected.Inf(_) => + (maybeSimple(fn), args.traverse(maybeSimple(_))) + .mapN { (infFn, infArgs) => + infFn.flatMap { fnTe => + fnTe.getType match { + case Type.Fun.SimpleUniversal(us, argsT, resT) + if argsT.length == args.length => + infArgs.sequence + .flatMap { argsTE => + val argTypes = argsTE.map(_.getType) + Type.instantiate( + us.toList.toMap, + Type.Tuple(argsT.toList), + Type.Tuple(argTypes.toList) + ) match { + case None => + pureNone + case Some((frees, inst)) => + if (frees.nonEmpty) { + // TODO maybe we could handle this, but not yet + pureNone + } else { + val resType = Type.substituteVar( + resT, + inst.view.mapValues(_._2).toMap + ) + val resTe = TypedExpr.App( + fnTe, + argsTE, + resType, + term.tag + ) + + instSigma( + resType, + inf, + HasRegion.region(term) + ) + .map(co => Some(co(resTe))) + } + } + } + case _ => + pureNone + } + } + } + .flatSequence + .flatMap { + case Some(te) => pure(te) + case None => + applyRhoExpect(fn, args, tag, inf) + } + } case Generic(tpes, in) => for { diff --git a/core/src/test/scala/org/bykn/bosatsu/rankn/RankNInferTest.scala b/core/src/test/scala/org/bykn/bosatsu/rankn/RankNInferTest.scala index 1908973fe..0fd1d75e6 100644 --- a/core/src/test/scala/org/bykn/bosatsu/rankn/RankNInferTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/rankn/RankNInferTest.scala @@ -201,7 +201,7 @@ class RankNInferTest extends AnyFunSuite { ) assert(Type.metaTvs(tp :: Nil).isEmpty, s"illegal inferred type: $teStr") - assert(te.getType.sameAs(typeFrom(tpe)), s"found: ${te.repr}") + assert(te.getType.sameAs(typeFrom(tpe)), s"found: ${te.repr.render(80)}") } // this could be used to test the string representation of expressions @@ -1576,6 +1576,17 @@ x = hide(y) """# struct Tup(a, b) +def hide[b](x: b) -> exists a. a: x +x = hide(1) +y = hide("1") +z: Tup[exists a. a, exists b. b] = Tup(x, y) +""", + "Tup[exists a. a, exists b. b]" + ) + parseProgram( + """# +struct Tup(a, b) + def hide[b](x: b) -> exists a. a: x def makeTup[a, b](x: a, y: b) -> Tup[a, b]: Tup(x, y) x = hide(1) @@ -1835,4 +1846,19 @@ f3: Foo[forall a. a] = f1 "Foo[forall a. a]" ) } + + test("test Liskov example") { + parseProgram( + """ +struct Sub[a: -*, b: +*](sub: forall f: +* -> *. f[a] -> f[b]) +struct Tup(a, b) + +refl_sub: forall a. Sub[a, a] = Sub(x -> x) +refl_any: Sub[forall a. a, exists a. a] = refl_sub +#ignore: Tup[forall a. Sub[a, a], Sub[forall a. a, exists a. a]] = Tup(refl_sub, refl_any) +ignore = Tup(refl_sub, refl_any) +""", + "Tup[forall a. Sub[a, a], Sub[forall a. a, exists a. a]]" + ) + } } diff --git a/test_workspace/TypeConstraint.bosatsu b/test_workspace/TypeConstraint.bosatsu index a2621edbb..3dc7fa408 100644 --- a/test_workspace/TypeConstraint.bosatsu +++ b/test_workspace/TypeConstraint.bosatsu @@ -77,4 +77,4 @@ refl_bottom1: Sub[forall a. a, forall a. a] = refl_sub refl_Int: Sub[forall a. a, Int] = refl_sub refl_any: Sub[forall a. a, exists a. a] = refl_sub refl_any1: Sub[exists a. a, exists a. a] = refl_sub -refl_Int_any: Sub[Int, exists a. a] = refl_sub +refl_Int_any: Sub[Int, exists a. a] = refl_sub \ No newline at end of file From 743ae68f512f0090ecc2102adbb8cd3f9c3ca063 Mon Sep 17 00:00:00 2001 From: Oscar Boykin Date: Thu, 7 Mar 2024 22:34:38 -0700 Subject: [PATCH 2/5] Fix TypedExprTest --- .../org/bykn/bosatsu/TypedExprNormalization.scala | 3 +++ core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala | 10 +++++----- .../test/scala/org/bykn/bosatsu/TypedExprTest.scala | 2 +- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/bykn/bosatsu/TypedExprNormalization.scala b/core/src/main/scala/org/bykn/bosatsu/TypedExprNormalization.scala index 25bd8b345..dc1d54c4e 100644 --- a/core/src/main/scala/org/bykn/bosatsu/TypedExprNormalization.scala +++ b/core/src/main/scala/org/bykn/bosatsu/TypedExprNormalization.scala @@ -327,6 +327,9 @@ object TypedExprNormalization { } f1 match { + // TODO: what if f1: Generic(_, AnnotatedLambda(_, _, _)) + // we should still be able ton convert this to a let by + // instantiating to the right args case AnnotatedLambda(lamArgs, expr, _) => // (y -> z)(x) = let y = x in z val lets = lamArgs.zip(args).map { case ((n, ltpe), arg) => diff --git a/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala b/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala index 9eeda447b..e89e1f938 100644 --- a/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala +++ b/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala @@ -1549,12 +1549,12 @@ object Infer { // TODO maybe we could handle this, but not yet pureNone } else { - val resType = Type.substituteVar( - resT, - inst.view.mapValues(_._2).toMap - ) + val subMap = inst.view.mapValues(_._2).toMap[Type.Var, Type] + val fnType0 = Type.Fun(argsT, resT) + val fnType1 = Type.substituteVar(fnType0, subMap) + val resType = Type.substituteVar(resT, subMap) val resTe = TypedExpr.App( - fnTe, + TypedExpr.Annotation(fnTe, fnType1), argsTE, resType, term.tag diff --git a/core/src/test/scala/org/bykn/bosatsu/TypedExprTest.scala b/core/src/test/scala/org/bykn/bosatsu/TypedExprTest.scala index eed1f93ff..7ed0b68b7 100644 --- a/core/src/test/scala/org/bykn/bosatsu/TypedExprTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/TypedExprTest.scala @@ -605,7 +605,7 @@ res = ( checkLast(""" res = _ -> 1 """) { te2 => - assert(te1.void == te2.void, s"${te1.repr} != ${te2.repr}") + assert(te1.void == te2.void, s"${te1.repr.render(80)} != ${te2.repr.render(80)}") } } From 3fa3d601a6b4d1a6c37cbc999a2c6cf613474c50 Mon Sep 17 00:00:00 2001 From: Oscar Boykin Date: Fri, 8 Mar 2024 10:36:50 -0700 Subject: [PATCH 3/5] improve instantiation --- .../scala/org/bykn/bosatsu/rankn/Infer.scala | 212 +++++++++++++----- .../scala/org/bykn/bosatsu/rankn/Type.scala | 62 ++++- .../bykn/bosatsu/rankn/RankNInferTest.scala | 2 +- .../org/bykn/bosatsu/rankn/TypeTest.scala | 6 +- test_workspace/TypeConstraint.bosatsu | 4 +- 5 files changed, 216 insertions(+), 70 deletions(-) diff --git a/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala b/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala index e89e1f938..5e4d48b20 100644 --- a/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala +++ b/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala @@ -1176,7 +1176,7 @@ object Infer { ): Option[Infer[TypedExpr.Coerce]] = inferred match { case Type.ForAll(vars, inT) => - Type.instantiate(vars.iterator.toMap, inT, declared).map { + Type.instantiate(vars.iterator.toMap, inT, declared, Map.empty).map { case (_, subs) => validateSubs(subs.toList, left, right) .as { @@ -1375,8 +1375,9 @@ object Infer { if inT.length == args.length => // see if we can instantiate the result type // if we can, we use that to fix the known parameters and continue - Type.instantiate(univ.iterator.toMap, outT, tpe).flatMap { - case (frees, inst) => + Type + .instantiate(univ.iterator.toMap, outT, tpe, Map.empty) + .flatMap { case (frees, inst) => // if instantiate works, we know outT => tpe if (inst.nonEmpty && frees.isEmpty) { // we made some progress and there are no frees @@ -1387,7 +1388,7 @@ object Infer { // We learned nothing None } - } + } case _ => None } @@ -1460,6 +1461,137 @@ object Infer { } } + // noshadow must include any free vars of args + def liftQuantification[A]( + args: NonEmptyList[TypedExpr[A]], + noshadow: Set[Type.Var.Bound] + ): ( + Option[Type.Quantification], + NonEmptyList[TypedExpr[A]] + ) = { + + val htype = args.head.getType + val (oq, rest) = NonEmptyList.fromList(args.tail) match { + case Some(neTail) => + val (oq, rest) = liftQuantification(neTail, noshadow) + (oq, rest.toList) + case None => + (None, Nil) + } + + htype match { + case Type.Quantified(q, rho) => + oq match { + case Some(qtail) => + // we have to unshadow with noshadow + all the vars in the tail + val (map, q1) = + q.unshadow(noshadow ++ qtail.vars.toList.iterator.map(_._1)) + val rho1 = Type.substituteRhoVar(rho, map) + ( + Some(q1.concat(qtail)), + NonEmptyList(TypedExpr.Annotation(args.head, rho1), rest) + ) + case None => + val (map, q1) = q.unshadow(noshadow) + val rho1 = Type.substituteRhoVar(rho, map) + ( + Some(q1), + NonEmptyList(TypedExpr.Annotation(args.head, rho1), rest) + ) + } + + case _ => + (oq, NonEmptyList(args.head, rest)) + } + } + + def applyViaInst[A: HasRegion]( + fn: Expr[A], + args: NonEmptyList[Expr[A]], + tag: A, + region: Region, + exp: Expected[(Type.Rho, Region)] + ): Infer[Option[TypedExpr[A]]] = + (maybeSimple(fn), args.traverse(maybeSimple(_))).mapN { + (infFn, infArgs) => + infFn.flatMap { fnTe => + fnTe.getType match { + case Type.Fun.SimpleUniversal(us, argsT, resT) + if argsT.length == args.length => + infArgs.sequence + .flatMap { argsTE => + val argTypes = argsTE.map(_.getType) + // we can lift any quantification of the args + // outside of the function application + // We have to lift *before* substitution + val noshadows = + Type.freeBoundTyVars(resT :: argTypes.toList).toSet ++ + us.iterator.map(_._1) + val (optQ, liftArgs) = + liftQuantification(argsTE, noshadows) + + val liftArgTypes = liftArgs.map(_.getType) + Type.instantiate( + us.toList.toMap, + Type.Tuple(argsT.toList), + Type.Tuple(liftArgTypes.toList), + optQ.fold(Map.empty[Type.Var.Bound, Kind])( + _.vars.toList.toMap + ) + ) match { + case None => + /* + println(s"can't instantiate: ${ + Type.fullyResolvedDocument.document(fnTe.getType).render(80) + } to ${liftArgTypes.map(Type.fullyResolvedDocument.document(_).render(80))}") + */ + pureNone + case Some((frees, inst)) => + if (frees.nonEmpty) { + // TODO maybe we could handle this, but not yet + // seems like if the free vars are set to the same + // variable, then we can just lift it into the + // quantification + /* + println(s"remaining frees in ${ + Type.fullyResolvedDocument.document(fnTe.getType).render(80) + } to ${liftArgTypes.map(Type.fullyResolvedDocument.document(_).render(80))}: $frees") + */ + pureNone + } else { + val subMap = + inst.view.mapValues(_._2).toMap[Type.Var, Type] + val fnType0 = Type.Fun(liftArgTypes, resT) + val fnType1 = Type.substituteVar(fnType0, subMap) + val resType = Type.substituteVar(resT, subMap) + + val resTe = TypedExpr.App( + TypedExpr.Annotation(fnTe, fnType1), + liftArgs, + resType, + tag + ) + + val maybeQuant = optQ match { + case Some(q) => TypedExpr.Generic(q, resTe) + case None => resTe + } + + instSigma( + maybeQuant.getType, + exp, + region + ) + .map(co => Some(co(maybeQuant))) + } + } + } + case _ => + pureNone + } + } + }.flatSequence + def applyRhoExpect[A: HasRegion]( fn: Expr[A], args: NonEmptyList[Expr[A]], @@ -1524,64 +1656,17 @@ object Infer { zonkTypedExpr(typedTerm).map(coerce(_)) } case App(fn, args, tag) => - expect match { - case Expected.Check((rho, reg)) => - checkApply(fn, args, tag, rho, reg) - case inf @ Expected.Inf(_) => - (maybeSimple(fn), args.traverse(maybeSimple(_))) - .mapN { (infFn, infArgs) => - infFn.flatMap { fnTe => - fnTe.getType match { - case Type.Fun.SimpleUniversal(us, argsT, resT) - if argsT.length == args.length => - infArgs.sequence - .flatMap { argsTE => - val argTypes = argsTE.map(_.getType) - Type.instantiate( - us.toList.toMap, - Type.Tuple(argsT.toList), - Type.Tuple(argTypes.toList) - ) match { - case None => - pureNone - case Some((frees, inst)) => - if (frees.nonEmpty) { - // TODO maybe we could handle this, but not yet - pureNone - } else { - val subMap = inst.view.mapValues(_._2).toMap[Type.Var, Type] - val fnType0 = Type.Fun(argsT, resT) - val fnType1 = Type.substituteVar(fnType0, subMap) - val resType = Type.substituteVar(resT, subMap) - val resTe = TypedExpr.App( - TypedExpr.Annotation(fnTe, fnType1), - argsTE, - resType, - term.tag - ) - - instSigma( - resType, - inf, - HasRegion.region(term) - ) - .map(co => Some(co(resTe))) - } - } - } - case _ => - pureNone - } - } - } - .flatSequence - .flatMap { - case Some(te) => pure(te) - case None => + applyViaInst(fn, args, tag, HasRegion.region(tag), expect) + .flatMap { + case Some(te) => pure(te) + case None => + expect match { + case Expected.Check((rho, reg)) => + checkApply(fn, args, tag, rho, reg) + case inf @ Expected.Inf(_) => applyRhoExpect(fn, args, tag, inf) } - - } + } case Generic(tpes, in) => for { unSkol <- inferForAll(tpes, in) @@ -1747,7 +1832,12 @@ object Infer { simp.flatMap { te => te.getType match { case Type.ForAll(fas, in) => - Type.instantiate(fas.iterator.toMap, in, rho) match { + Type.instantiate( + fas.iterator.toMap, + in, + rho, + Map.empty + ) match { case Some((frees, subs)) if frees.isEmpty => // we know that substituting in gives rho // check kinds diff --git a/core/src/main/scala/org/bykn/bosatsu/rankn/Type.scala b/core/src/main/scala/org/bykn/bosatsu/rankn/Type.scala index 24bbd18e6..39c21f860 100644 --- a/core/src/main/scala/org/bykn/bosatsu/rankn/Type.scala +++ b/core/src/main/scala/org/bykn/bosatsu/rankn/Type.scala @@ -69,6 +69,46 @@ object Type { def forallList: List[(Var.Bound, Kind)] def concat(that: Quantification): Quantification + // Return this quantification, where the vars avoid otherVars + def unshadow( + otherVars: Set[Var.Bound] + ): (Map[Var, TyVar], Quantification) = { + def unshadowNel( + nel: NonEmptyList[(Var.Bound, Kind)] + ): (Map[Var, TyVar], NonEmptyList[(Var.Bound, Kind)]) = { + val remap: Map[Var, Var.Bound] = { + val collisions = nel.toList.filter { case (b, _) => otherVars(b) } + val nonCollisions = nel.iterator.filterNot { case (b, _) => + otherVars(b) + } + val colMap = + alignBinders(collisions, otherVars ++ nonCollisions.map(_._1)) + colMap.iterator.map { case ((b, _), b1) => (b, b1) }.toMap + } + + val nel1 = nel.map { case bk @ (b, k) => + remap.get(b) match { + case None => bk + case Some(b1) => (b1, k) + } + } + (remap.view.mapValues(TyVar(_)).toMap, nel1) + } + + if (vars.exists { case (b, _) => otherVars(b) }) { + this match { + case Quantification.Dual(foralls, exists) => + val (mfa, fa) = unshadowNel(foralls) + val (mex, ex) = unshadowNel(exists) + (mfa ++ mex, Quantification.Dual(fa, ex)) + case Quantification.ForAll(forAll) => + unshadowNel(forAll).map(Quantification.ForAll(_)) + case Quantification.Exists(exists) => + unshadowNel(exists).map(Quantification.Exists(_)) + } + } else (Map.empty, this) + } + def filter(fn: Var.Bound => Boolean): Option[Quantification] = Quantification.fromLists( forallList.filter { case (b, _) => fn(b) }, @@ -503,7 +543,12 @@ object Type { /** Kind of the opposite of substitute: given a Map of vars, can we set those * vars to some Type and get from to match to exactly */ - def instantiate(vars: Map[Var.Bound, Kind], from: Type, to: Type): Option[ + def instantiate( + vars: Map[Var.Bound, Kind], + from: Type, + to: Type, + env: Map[Var.Bound, Kind] + ): Option[ ( SortedMap[Var.Bound, (Kind, Var.Bound)], SortedMap[Var.Bound, (Kind, Type)] @@ -540,17 +585,26 @@ object Type { opt match { case Unknown => to match { - case TyVar(toB: Var.Bound) => + case tv @ TyVar(toB: Var.Bound) => state.rightFrees.get(toB) match { case Some(toBKind) => if (Kind.leftSubsumesRight(kind, toBKind)) { Some(state.updated(b, (toBKind, Free(toB)))) } else None - case None => None + case None => + env.get(toB) match { + case Some(toBKind) + if (Kind.leftSubsumesRight(kind, toBKind)) => + Some(state.updated(b, (toBKind, Fixed(tv)))) + case _ => None + } // don't set to vars to non-free bound variables // this shouldn't happen in real inference } - case _ if hasNoUnboundVars(to) => + case _ + if freeBoundTyVars(to :: Nil) + .filterNot(env.keySet) + .isEmpty => Some(state.updated(b, (kind, Fixed(to)))) case _ => None } diff --git a/core/src/test/scala/org/bykn/bosatsu/rankn/RankNInferTest.scala b/core/src/test/scala/org/bykn/bosatsu/rankn/RankNInferTest.scala index 0fd1d75e6..b9f6cbcab 100644 --- a/core/src/test/scala/org/bykn/bosatsu/rankn/RankNInferTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/rankn/RankNInferTest.scala @@ -1858,7 +1858,7 @@ refl_any: Sub[forall a. a, exists a. a] = refl_sub #ignore: Tup[forall a. Sub[a, a], Sub[forall a. a, exists a. a]] = Tup(refl_sub, refl_any) ignore = Tup(refl_sub, refl_any) """, - "Tup[forall a. Sub[a, a], Sub[forall a. a, exists a. a]]" + "forall a. Tup[Sub[a, a], Sub[forall a. a, exists a. a]]" ) } } diff --git a/core/src/test/scala/org/bykn/bosatsu/rankn/TypeTest.scala b/core/src/test/scala/org/bykn/bosatsu/rankn/TypeTest.scala index d452ad9c3..eeaa93f2b 100644 --- a/core/src/test/scala/org/bykn/bosatsu/rankn/TypeTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/rankn/TypeTest.scala @@ -442,7 +442,7 @@ class TypeTest extends AnyFunSuite { forAll(NTypeGen.genDepth03, NTypeGen.genDepth03) { (t1, t2) => t1 match { case Type.ForAll(fas, t) => - Type.instantiate(fas.iterator.toMap, t, t2) match { + Type.instantiate(fas.iterator.toMap, t, t2, Map.empty) match { case Some((frees, subs)) => val t3 = Type.substituteVar( t, @@ -477,7 +477,7 @@ class TypeTest extends AnyFunSuite { def check(forall: String, matches: String, subs: List[(String, String)]) = { val Type.ForAll(fas, t) = parse(forall) val targ = parse(matches) - Type.instantiate(fas.iterator.toMap, t, targ) match { + Type.instantiate(fas.iterator.toMap, t, targ, Map.empty) match { case Some((_, subMap)) => assert(subMap.size == subs.size) subs.foreach { case (k, v) => @@ -492,7 +492,7 @@ class TypeTest extends AnyFunSuite { def noSub(forall: String, matches: String) = { val Type.ForAll(fas, t) = parse(forall) val targ = parse(matches) - val res = Type.instantiate(fas.iterator.toMap, t, targ) + val res = Type.instantiate(fas.iterator.toMap, t, targ, Map.empty) assert(res == None) } diff --git a/test_workspace/TypeConstraint.bosatsu b/test_workspace/TypeConstraint.bosatsu index 3dc7fa408..c2e66f592 100644 --- a/test_workspace/TypeConstraint.bosatsu +++ b/test_workspace/TypeConstraint.bosatsu @@ -77,4 +77,6 @@ refl_bottom1: Sub[forall a. a, forall a. a] = refl_sub refl_Int: Sub[forall a. a, Int] = refl_sub refl_any: Sub[forall a. a, exists a. a] = refl_sub refl_any1: Sub[exists a. a, exists a. a] = refl_sub -refl_Int_any: Sub[Int, exists a. a] = refl_sub \ No newline at end of file +refl_Int_any: Sub[Int, exists a. a] = refl_sub + +ignore = (refl_bottom, refl_bottom1, refl_Int, refl_any) \ No newline at end of file From 7632aae02fcfd80cf003dd9dafbdd527be7123d8 Mon Sep 17 00:00:00 2001 From: Oscar Boykin Date: Fri, 8 Mar 2024 17:26:19 -0700 Subject: [PATCH 4/5] increase substitution based subtyping --- .../scala/org/bykn/bosatsu/rankn/Infer.scala | 73 +++++++++++++------ .../bykn/bosatsu/rankn/RankNInferTest.scala | 46 ++++++++++-- 2 files changed, 90 insertions(+), 29 deletions(-) diff --git a/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala b/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala index 5e4d48b20..5bb52d9d5 100644 --- a/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala +++ b/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala @@ -1174,7 +1174,7 @@ object Infer { left: Region, right: Region ): Option[Infer[TypedExpr.Coerce]] = - inferred match { + (inferred match { case Type.ForAll(vars, inT) => Type.instantiate(vars.iterator.toMap, inT, declared, Map.empty).map { case (_, subs) => @@ -1189,8 +1189,26 @@ object Infer { } } } - case _ => None - } + case _ => + None + }).orElse(declared match { + case Type.Exists(vars, inT) => + Type.instantiate(vars.iterator.toMap, inT, inferred, Map.empty).map { + case (_, subs) => + validateSubs(subs.toList, left, right) + .as { + new FunctionK[TypedExpr, TypedExpr] { + def apply[A](te: TypedExpr[A]): TypedExpr[A] = + // we apply the annotation here and let Normalization + // instantiate. We could explicitly have + // instantiation TypedExpr where you pass the variables to set + TypedExpr.Annotation(te, declared) + } + } + } + case _ => + None + }) // note, this is identical to subsCheckRho when declared is a Rho type def subsCheck( inferred: Type, @@ -1508,9 +1526,7 @@ object Infer { def applyViaInst[A: HasRegion]( fn: Expr[A], args: NonEmptyList[Expr[A]], - tag: A, - region: Region, - exp: Expected[(Type.Rho, Region)] + tag: A ): Infer[Option[TypedExpr[A]]] = (maybeSimple(fn), args.traverse(maybeSimple(_))).mapN { (infFn, infArgs) => @@ -1577,12 +1593,7 @@ object Infer { case None => resTe } - instSigma( - maybeQuant.getType, - exp, - region - ) - .map(co => Some(co(maybeQuant))) + pure(Some(maybeQuant)) } } } @@ -1648,17 +1659,36 @@ object Infer { res <- zonkTypedExpr(TypedExpr.Global(pack, name, vSigma, tag)) } yield coerce(res) case Annotation(App(fn, args, tag), resT, annTag) => - ( - checkApply(fn, args, tag, resT, region(annTag)), - instSigma(resT, expect, region(annTag)) - ) - .parFlatMapN { (typedTerm, coerce) => - zonkTypedExpr(typedTerm).map(coerce(_)) + applyViaInst(fn, args, tag) + .flatMap { + case Some(te) => + for { + co1 <- subsCheck( + te.getType, + resT, + region(tag), + region(annTag) + ) + co2 <- instSigma(resT, expect, region(annTag)) + z <- zonkTypedExpr(te) + } yield co2(co1(z)) + case None => + ( + checkApply(fn, args, tag, resT, region(annTag)), + instSigma(resT, expect, region(annTag)) + ) + .parFlatMapN { (typedTerm, coerce) => + zonkTypedExpr(typedTerm).map(coerce(_)) + } } case App(fn, args, tag) => - applyViaInst(fn, args, tag, HasRegion.region(tag), expect) + applyViaInst(fn, args, tag) .flatMap { - case Some(te) => pure(te) + case Some(te) => + for { + co <- instSigma(te.getType, expect, HasRegion.region(tag)) + z <- zonkTypedExpr(te) + } yield co(z) case None => expect match { case Expected.Check((rho, reg)) => @@ -2389,8 +2419,7 @@ object Infer { for { _ <- init - rhoT <- inferRho(e) - (rho, expTyRho) = rhoT + (rho, expTyRho) <- inferRho(e) q <- quantify(unifySelf(expTyRho), rho) } yield q } diff --git a/core/src/test/scala/org/bykn/bosatsu/rankn/RankNInferTest.scala b/core/src/test/scala/org/bykn/bosatsu/rankn/RankNInferTest.scala index b9f6cbcab..ce770f93b 100644 --- a/core/src/test/scala/org/bykn/bosatsu/rankn/RankNInferTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/rankn/RankNInferTest.scala @@ -201,7 +201,13 @@ class RankNInferTest extends AnyFunSuite { ) assert(Type.metaTvs(tp :: Nil).isEmpty, s"illegal inferred type: $teStr") - assert(te.getType.sameAs(typeFrom(tpe)), s"found: ${te.repr.render(80)}") + val expectedTpe = typeFrom(tpe) + val expectedTpeStr = + Type.fullyResolvedDocument.document(expectedTpe).render(80) + assert( + te.getType.sameAs(expectedTpe), + s"$teStr != $expectedTpeStr\n\nfound: ${te.repr.render(80)}" + ) } // this could be used to test the string representation of expressions @@ -1602,9 +1608,9 @@ enum B: T, F struct Inv[a: *](item: a) any: exists a. a = T -x: Inv[exists a. a] = Inv(any) +x: exists a. Inv[a] = Inv(any) """, - "Inv[exists a. a]" + "exists a. Inv[a]" ) } @@ -1851,14 +1857,40 @@ f3: Foo[forall a. a] = f1 parseProgram( """ struct Sub[a: -*, b: +*](sub: forall f: +* -> *. f[a] -> f[b]) -struct Tup(a, b) +struct Tup(a, b, c, d) +struct Foo refl_sub: forall a. Sub[a, a] = Sub(x -> x) +refl_bottom: forall b. Sub[forall a. a, b] = refl_sub +refl_bottom1: Sub[forall a. a, forall a. a] = refl_sub +refl_Foo: Sub[forall a. a, Foo] = refl_sub refl_any: Sub[forall a. a, exists a. a] = refl_sub -#ignore: Tup[forall a. Sub[a, a], Sub[forall a. a, exists a. a]] = Tup(refl_sub, refl_any) -ignore = Tup(refl_sub, refl_any) +refl_any1: Sub[exists a. a, exists a. a] = refl_sub +refl_Foo_any: Sub[Foo, exists a. a] = refl_sub + +ignore = Tup(refl_bottom, refl_bottom1, refl_Foo, refl_any) """, - "forall a. Tup[Sub[a, a], Sub[forall a. a, exists a. a]]" + "forall a. Tup[Sub[forall a. a, a], Sub[forall a. a, forall a. a]," + + "Sub[forall a. a, Foo], Sub[forall a. a, exists a. a]]" + ) + + parseProgram( + """ +struct Sub[a: -*, b: +*](sub: forall f: +* -> *. f[a] -> f[b]) +struct Tup(a, b, c, d) +struct Foo + +refl_sub: forall a. Sub[a, a] = Sub(x -> x) +refl_bottom: forall b. Sub[forall a. a, b] = refl_sub +refl_bottom1: Sub[forall a. a, forall a. a] = refl_sub +refl_Foo: Sub[forall a. a, Foo] = refl_sub +refl_any: Sub[forall a. a, exists a. a] = refl_sub +refl_any1: Sub[exists a. a, exists a. a] = refl_sub +refl_Foo_any: Sub[Foo, exists a. a] = refl_sub + +ignore: exists a. a = Tup(refl_bottom, refl_bottom1, refl_Foo, refl_any) +""", + "exists a. a" ) } } From d0d643a3c1081b14c14e2336b2f6252faca05e18 Mon Sep 17 00:00:00 2001 From: Oscar Boykin Date: Sun, 10 Mar 2024 10:02:08 -0700 Subject: [PATCH 5/5] some improvements --- .../scala/org/bykn/bosatsu/TypedExpr.scala | 59 ++++++++++++++++++- .../scala/org/bykn/bosatsu/rankn/Infer.scala | 50 ++-------------- .../org/bykn/bosatsu/TypedExprTest.scala | 17 ++++++ .../bykn/bosatsu/rankn/RankNInferTest.scala | 4 ++ 4 files changed, 82 insertions(+), 48 deletions(-) diff --git a/core/src/main/scala/org/bykn/bosatsu/TypedExpr.scala b/core/src/main/scala/org/bykn/bosatsu/TypedExpr.scala index 3d2e37b60..eedc57b24 100644 --- a/core/src/main/scala/org/bykn/bosatsu/TypedExpr.scala +++ b/core/src/main/scala/org/bykn/bosatsu/TypedExpr.scala @@ -974,9 +974,7 @@ object TypedExpr { } private def allPatternTypes[N](p: Pattern[N, Type]): SortedSet[Type] = - p.traverseType(t => Writer[SortedSet[Type], Type](SortedSet(t), t)) - .run - ._1 + p.traverseType(t => Writer[SortedSet[Type], Type](SortedSet(t), t)).run._1 // Invariant, nel must have at least one item in common with quant.vars private def filterQuant( @@ -1390,6 +1388,16 @@ object TypedExpr { if Type.quantify(q, tpe).sameAs(term.getType) => // we not uncommonly add an annotation just to make a generic wrapper to get back where term + case Annotation(term, tpe) + if !q.vars.iterator + .map(_._1) + .exists( + Type.freeBoundTyVars(expr.getType :: Nil).toSet + ) => + // the variables may be free lower, but not here + val genTerm = normalizeQuantVars(q, term) + if (genTerm.getType.sameAs(tpe)) genTerm + else Annotation(normalizeQuantVars(q, term), tpe) case _ => import Type.Quantification._ // We cannot rebind to any used typed inside of expr, but we can reuse @@ -1478,4 +1486,49 @@ object TypedExpr { implicit def typedExprHasRegion[T: HasRegion]: HasRegion[TypedExpr[T]] = HasRegion.instance[TypedExpr[T]](e => HasRegion.region(e.tag)) + + // noshadow must include any free vars of args + def liftQuantification[A]( + args: NonEmptyList[TypedExpr[A]], + noshadow: Set[Type.Var.Bound] + ): ( + Option[Type.Quantification], + NonEmptyList[TypedExpr[A]] + ) = { + + val htype = args.head.getType + val (oq, rest) = NonEmptyList.fromList(args.tail) match { + case Some(neTail) => + val (oq, rest) = liftQuantification(neTail, noshadow) + (oq, rest.toList) + case None => + (None, Nil) + } + + htype match { + case Type.Quantified(q, rho) => + oq match { + case Some(qtail) => + // we have to unshadow with noshadow + all the vars in the tail + val (map, q1) = + q.unshadow(noshadow ++ qtail.vars.toList.iterator.map(_._1)) + val rho1 = Type.substituteRhoVar(rho, map) + ( + Some(q1.concat(qtail)), + NonEmptyList(TypedExpr.Annotation(args.head, rho1), rest) + ) + case None => + val (map, q1) = q.unshadow(noshadow) + val rho1 = Type.substituteRhoVar(rho, map) + ( + Some(q1), + NonEmptyList(TypedExpr.Annotation(args.head, rho1), rest) + ) + } + + case _ => + (oq, NonEmptyList(args.head, rest)) + } + } + } diff --git a/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala b/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala index 5bb52d9d5..e305357a0 100644 --- a/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala +++ b/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala @@ -1207,6 +1207,10 @@ object Infer { } } case _ => + // TODO: we should be able to handle Dual quantification which could + // solve more cases. The challenge is existentials and universals appear + // on different sides, so cases where both need solutions can't be done + // with the current method that only solves one direction now. None }) // note, this is identical to subsCheckRho when declared is a Rho type @@ -1479,50 +1483,6 @@ object Infer { } } - // noshadow must include any free vars of args - def liftQuantification[A]( - args: NonEmptyList[TypedExpr[A]], - noshadow: Set[Type.Var.Bound] - ): ( - Option[Type.Quantification], - NonEmptyList[TypedExpr[A]] - ) = { - - val htype = args.head.getType - val (oq, rest) = NonEmptyList.fromList(args.tail) match { - case Some(neTail) => - val (oq, rest) = liftQuantification(neTail, noshadow) - (oq, rest.toList) - case None => - (None, Nil) - } - - htype match { - case Type.Quantified(q, rho) => - oq match { - case Some(qtail) => - // we have to unshadow with noshadow + all the vars in the tail - val (map, q1) = - q.unshadow(noshadow ++ qtail.vars.toList.iterator.map(_._1)) - val rho1 = Type.substituteRhoVar(rho, map) - ( - Some(q1.concat(qtail)), - NonEmptyList(TypedExpr.Annotation(args.head, rho1), rest) - ) - case None => - val (map, q1) = q.unshadow(noshadow) - val rho1 = Type.substituteRhoVar(rho, map) - ( - Some(q1), - NonEmptyList(TypedExpr.Annotation(args.head, rho1), rest) - ) - } - - case _ => - (oq, NonEmptyList(args.head, rest)) - } - } - def applyViaInst[A: HasRegion]( fn: Expr[A], args: NonEmptyList[Expr[A]], @@ -1544,7 +1504,7 @@ object Infer { Type.freeBoundTyVars(resT :: argTypes.toList).toSet ++ us.iterator.map(_._1) val (optQ, liftArgs) = - liftQuantification(argsTE, noshadows) + TypedExpr.liftQuantification(argsTE, noshadows) val liftArgTypes = liftArgs.map(_.getType) Type.instantiate( diff --git a/core/src/test/scala/org/bykn/bosatsu/TypedExprTest.scala b/core/src/test/scala/org/bykn/bosatsu/TypedExprTest.scala index 7ed0b68b7..7834ac68c 100644 --- a/core/src/test/scala/org/bykn/bosatsu/TypedExprTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/TypedExprTest.scala @@ -995,4 +995,21 @@ x = ( ) } } + + test("TypedExpr.liftQuantification makes all args Rho types") { + + forAll(Generators.smallNonEmptyList(genTypedExpr, 10), + Gen.containerOf[Set, Type.Var.Bound](NTypeGen.genBound)) { (tes, avoid) => + + val (optQuant, args) = TypedExpr.liftQuantification(tes, avoid) + args.toList.foreach { te => + te.getType match { + case _: Type.Rho => () + case notRho => fail(s"expected: $te to have rho type, got: $notRho") + } + } + val allRhos = tes.forall(_.getType.isInstanceOf[Type.Rho]) + assert(allRhos == optQuant.isEmpty) + } + } } diff --git a/core/src/test/scala/org/bykn/bosatsu/rankn/RankNInferTest.scala b/core/src/test/scala/org/bykn/bosatsu/rankn/RankNInferTest.scala index ce770f93b..31200abad 100644 --- a/core/src/test/scala/org/bykn/bosatsu/rankn/RankNInferTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/rankn/RankNInferTest.scala @@ -1610,6 +1610,10 @@ struct Inv[a: *](item: a) any: exists a. a = T x: exists a. Inv[a] = Inv(any) """, + // TODO: it would be nice to be able to annotate this as + // Inv[exists a. a] and get that to pass too + // even though, I think exists a. Inv[a] is a tighter type + // "exists a. Inv[a]" ) }