diff --git a/core/src/main/scala/org/bykn/bosatsu/DefRecursionCheck.scala b/core/src/main/scala/org/bykn/bosatsu/DefRecursionCheck.scala index 2214ae97..1fce5045 100644 --- a/core/src/main/scala/org/bykn/bosatsu/DefRecursionCheck.scala +++ b/core/src/main/scala/org/bykn/bosatsu/DefRecursionCheck.scala @@ -109,7 +109,7 @@ object DefRecursionCheck { case Def(defn) => // make this the same shape as a in declaration checkDef(TopLevel, defn.copy(result = (defn.result, ()))) - case ExternalDef(_, _, _) => + case ExternalDef(_, _, _, _) => unitValid } case _ => unitValid diff --git a/core/src/main/scala/org/bykn/bosatsu/Package.scala b/core/src/main/scala/org/bykn/bosatsu/Package.scala index 5ed481ff..a2b79b6d 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Package.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Package.scala @@ -273,6 +273,11 @@ object Package { Type.Const.Defined(p, TypeName(tds.name)) -> tds.region }.toMap + lazy val extDefRegions: Map[Identifier.Bindable, Region] = + stmts.iterator.collect { case ed: Statement.ExternalDef => + ed.name -> ed.region + }.toMap + optProg.flatMap { case Program((importedTypeEnv, parsedTypeEnv), lets, extDefs, _) => val inferVarianceParsed @@ -336,8 +341,17 @@ object Package { errs.map(PackageError.TotalityCheckError(p, _)) } + val theseExternals = + parsedTypeEnv + .externalDefs + .collect { case (pack, b, t) if pack === p => + // by construction this has to have all the regions + (b, (t, extDefRegions(b))) + } + .toMap + val inferenceEither = Infer - .typeCheckLets(p, lets) + .typeCheckLets(p, lets, theseExternals) .runFully( withFQN, Referant.typeConstructors(imps) ++ typeEnv.typeConstructors, diff --git a/core/src/main/scala/org/bykn/bosatsu/PackageError.scala b/core/src/main/scala/org/bykn/bosatsu/PackageError.scala index a029dd4f..521cd695 100644 --- a/core/src/main/scala/org/bykn/bosatsu/PackageError.scala +++ b/core/src/main/scala/org/bykn/bosatsu/PackageError.scala @@ -363,6 +363,20 @@ object PackageError { ) + Doc.hardLine + context + (doc, Some(region)) + case Infer.Error.KindExpectedType(tpe, kind, region) => + val tmap = showTypes(pack, tpe :: Nil) + val context = + lm.showRegion(region, 2, errColor) + .getOrElse( + Doc.str(region) + ) // we should highlight the whole region + val doc = Doc.text("expected type ") + + tmap(tpe) + Doc.text( + " to have kind *, which is to say be a valid value, but it is kind " + ) + Kind.toDoc(kind) + Doc.hardLine + + context + (doc, Some(region)) case Infer.Error.KindInvalidApply(applied, leftK, rightK, region) => val leftT = applied.on diff --git a/core/src/main/scala/org/bykn/bosatsu/SourceConverter.scala b/core/src/main/scala/org/bykn/bosatsu/SourceConverter.scala index a31c37ff..428a5a38 100644 --- a/core/src/main/scala/org/bykn/bosatsu/SourceConverter.scala +++ b/core/src/main/scala/org/bykn/bosatsu/SourceConverter.scala @@ -157,7 +157,7 @@ final class SourceConverter( SourceConverter.InvalidDefTypeParameters( args, freeVarsList, - ds, + Right(ds), region ), gen @@ -1227,7 +1227,7 @@ final class SourceConverter( values: List[Statement.ValueStatement] ): Result[Unit] = { val extDefNames = - values.collect { case ed @ Statement.ExternalDef(name, _, _) => + values.collect { case ed @ Statement.ExternalDef(name, _, _, _) => (name, ed.region) } @@ -1259,7 +1259,7 @@ final class SourceConverter( s match { case b @ Statement.Bind(_) => Some(Left(b)) case d @ Statement.Def(_) => Some(Right(d)) - case Statement.ExternalDef(_, _, _) => None + case Statement.ExternalDef(_, _, _, _) => None } def checkDefBind(s: Statement.ValueStatement): Result[Unit] = @@ -1391,7 +1391,7 @@ final class SourceConverter( stmts.toList.flatMap { case d @ Def(_) => (d.defstatement.name, RecursionKind.Recursive, Left(d)) :: Nil - case ExternalDef(_, _, _) => + case ExternalDef(_, _, _, _) => // we don't allow external defs to shadow at all, so skip it here Nil case Bind(BindingStatement(bound, decl, _)) => @@ -1456,7 +1456,7 @@ final class SourceConverter( } val withEx: List[Either[ExternalDef, Flattened]] = - stmts.collect { case e @ ExternalDef(_, _, _) => Left(e) }.toList ::: + stmts.collect { case e @ ExternalDef(_, _, _, _) => Left(e) }.toList ::: flatIn.map { case (b, _, Left(d @ Def(dstmt))) => Right(Left(Def(dstmt.copy(name = b))(d.region))) @@ -1513,7 +1513,7 @@ final class SourceConverter( (boundName, rec, l1) :: Nil } (topBound1, r) - case Left(ExternalDef(n, _, _)) => + case Left(ExternalDef(n, _, _, _)) => (topBound + n, success(Nil)) } }(SourceConverter.parallelIor)).map(_.flatten) @@ -1526,7 +1526,7 @@ final class SourceConverter( ], List[Statement]]] = { val stmts = Statement.valuesOf(ss).toList stmts - .collect { case ed @ Statement.ExternalDef(name, params, result) => + .collect { case ed @ Statement.ExternalDef(name, ta, params, result) => ( params.traverse(p => toType(p._2, ed.region)), toType(result, ed.region) @@ -1547,7 +1547,7 @@ final class SourceConverter( } } } - .map { (tpe: rankn.Type) => + .flatMap { (tpe: rankn.Type) => val freeVars = rankn.Type.freeTyVars(tpe :: Nil) // these vars were parsed so they are never skolem vars val freeBound = freeVars.map { @@ -1557,10 +1557,34 @@ final class SourceConverter( sys.error(s"invariant violation: parsed a skolem var: $s") // $COVERAGE-ON$ } - // TODO: Kind support parsing kinds - val maybeForAll = - rankn.Type.forAll(freeBound.map(n => (n, Kind.Type)), tpe) - (name, maybeForAll) + val finalTpe = ta match { + case None => + success(rankn.Type.forAll(freeBound.map(n => (n, Kind.Type)), tpe)) + case Some(frees0) => + val frees = frees0.map { case (ref, optK) => ref.toBoundVar -> optK } + if (frees.iterator.map(_._1).toSet === freeBound.toSet[rankn.Type.Var.Bound]) { + success(rankn.Type.forAll(frees.map { + case (v, None) => (v, Kind.Type) + case (v, Some(k)) => (v, k) + }, tpe)) + } + else { + val kindMap = frees.iterator.collect { case (v, Some(k)) => (v, k) }.toMap + val vs = freeBound.map { v => (v, kindMap.getOrElse(v, Kind.Type)) } + val t = rankn.Type.forAll(vs, tpe) + SourceConverter.partial( + SourceConverter.InvalidDefTypeParameters( + frees0, + freeBound, + Left(ed), + ed.region + ), + t + ) + } + } + + finalTpe.map(name -> _) } } // TODO: we could implement Iterable[Ior[A, B]] => Ior[A, Iterble[B]] @@ -1887,10 +1911,21 @@ object SourceConverter { final case class InvalidDefTypeParameters[B]( declaredParams: NonEmptyList[(TypeRef.TypeVar, Option[Kind])], free: List[Type.Var.Bound], - defstmt: DefStatement[Pattern.Parsed, B], + defstmt: Either[Statement.ExternalDef, DefStatement[Pattern.Parsed, B]], region: Region ) extends Error { + def name: Identifier.Bindable = defstmt match { + case Right(ds) => ds.name + case Left(ed) => ed.name + } + + def expectation: String = defstmt match { + case Right(_) => "a subset of" + case Left(_) => "the same as" + } + + def message = { def tstr(l: List[Type.Var.Bound]): String = l.iterator.map(_.name).mkString("[", ", ", "]") @@ -1903,7 +1938,7 @@ object SourceConverter { .renderTrim(80) val freeStr = tstr(free) - s"${defstmt.name.asString} found declared types: $decl, not a subset of $freeStr" + s"${name.asString} found declared types: $decl, not $expectation $freeStr" } } diff --git a/core/src/main/scala/org/bykn/bosatsu/Statement.scala b/core/src/main/scala/org/bykn/bosatsu/Statement.scala index d22c1add..f536cf98 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Statement.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Statement.scala @@ -34,8 +34,8 @@ sealed abstract class Statement { Struct(nm, typeArgs, args)(r) case Enum(nm, typeArgs, parts) => Enum(nm, typeArgs, parts)(r) - case ExternalDef(name, args, res) => - ExternalDef(name, args, res)(r) + case ExternalDef(name, ta, args, res) => + ExternalDef(name, ta, args, res)(r) case ExternalStruct(nm, targs) => ExternalStruct(nm, targs)(r) } @@ -83,7 +83,7 @@ object Statement { case Bind(BindingStatement(bound, _, _)) => bound.names // TODO Keep identifiers case Def(defstatement) => defstatement.name :: Nil - case ExternalDef(name, _, _) => name :: Nil + case ExternalDef(name, _, _, _) => name :: Nil } /** These are all the free bindable names in the right hand side of this @@ -98,7 +98,7 @@ object Statement { (innerFrees - defstatement.name) -- defstatement.args.toList.flatMap( _.patternNames ) - case ExternalDef(_, _, _) => SortedSet.empty + case ExternalDef(_, _, _, _) => SortedSet.empty } /** These are all the bindings, free or not, in this Statement @@ -109,7 +109,7 @@ object Statement { case Def(defstatement) => (defstatement.result.get.allNames + defstatement.name) ++ defstatement.args.toList .flatMap(_.patternNames) - case ExternalDef(name, _, _) => SortedSet(name) + case ExternalDef(name, _, _, _) => SortedSet(name) } } @@ -126,6 +126,7 @@ object Statement { extends ValueStatement case class ExternalDef( name: Bindable, + typeArgs: Option[NonEmptyList[(TypeRef.TypeVar, Option[Kind])]], params: List[(Bindable, TypeRef)], result: TypeRef )(val region: Region) @@ -230,6 +231,10 @@ object Statement { val externalDef = { + val kindAnnot: P[Kind] = + (maybeSpace.soft.with1 *> (P.char(':') *> maybeSpace *> Kind.parser)) + val typeParams = TypeRef.typeParams(kindAnnot.?).? + val args = P.char('(') *> maybeSpace *> argParser.nonEmptyList <* maybeSpace <* P .char(')') @@ -239,16 +244,16 @@ object Statement { (((keySpace( "def" - ) *> Identifier.bindableParser ~ args ~ result).region) <* toEOL) - .map { case (region, ((name, args), resType)) => - ExternalDef(name, args.toList, resType)(region) + ) *> Identifier.bindableParser ~ typeParams ~ args ~ result).region) <* toEOL) + .map { case (region, (((name, tps), args), resType)) => + ExternalDef(name, tps, args.toList, resType)(region) } } val externalVal = (argParser <* toEOL).region .map { case (region, (name, resType)) => - ExternalDef(name, Nil, resType)(region) + ExternalDef(name, None, Nil, resType)(region) } keySpace("external") *> P.oneOf( @@ -385,11 +390,19 @@ object Statement { .char(':') + colonSep + indentedCons + Doc.line - case ExternalDef(name, Nil, res) => + case ExternalDef(name, None, Nil, res) => Doc.text("external ") + Document[Bindable].document(name) + Doc.text( ": " ) + res.toDoc + Doc.line - case ExternalDef(name, args, res) => + case ExternalDef(name, tps, args, res) => + val taDoc = tps match { + case None => Doc.empty + case Some(ta) => + TypeRef.docTypeArgs(ta.toList) { + case None => Doc.empty + case Some(k) => colonSpace + Kind.toDoc(k) + } + } val argDoc = { val da = Doc.intercalate( Doc.text(", "), @@ -401,7 +414,7 @@ object Statement { } Doc.text("external def ") + Document[Bindable].document( name - ) + argDoc + Doc.text(" -> ") + res.toDoc + Doc.line + ) + taDoc + argDoc + Doc.text(" -> ") + res.toDoc + Doc.line case ExternalStruct(nm, typeArgs) => val taDoc = TypeRef.docTypeArgs(typeArgs.toList) { diff --git a/core/src/main/scala/org/bykn/bosatsu/TypeRefConverter.scala b/core/src/main/scala/org/bykn/bosatsu/TypeRefConverter.scala index 15abe8eb..886c92dc 100644 --- a/core/src/main/scala/org/bykn/bosatsu/TypeRefConverter.scala +++ b/core/src/main/scala/org/bykn/bosatsu/TypeRefConverter.scala @@ -21,7 +21,7 @@ object TypeRefConverter { import TypeRef._ t match { - case TypeVar(v) => Applicative[F].pure(TyVar(Type.Var.Bound(v))) + case tv @ TypeVar(_) => Applicative[F].pure(TyVar(tv.toBoundVar)) case TypeName(n) => nameToType(n.ident).map(TyConst(_)) case TypeArrow(as, b) => (as.traverse(toType(_)), toType(b)).mapN(Fun(_, _)) diff --git a/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala b/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala index e305357a..4a07e343 100644 --- a/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala +++ b/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala @@ -182,6 +182,11 @@ object Infer { rightK: Kind, region: Region ) extends TypeError + case class KindExpectedType( + tpe: Type, + kind: Kind.Cons, + region: Region + ) extends TypeError case class KindMismatch( target: Type, targetKind: Kind, @@ -2614,7 +2619,8 @@ object Infer { */ def typeCheckLets[A: HasRegion]( pack: PackageName, - ls: List[(Bindable, RecursionKind, Expr[A])] + ls: List[(Bindable, RecursionKind, Expr[A])], + externals: Map[Bindable, (Type, Region)] ): Infer[List[(Bindable, RecursionKind, TypedExpr[A])]] = { // Group together lets that don't include each other to get more type errors // if we can @@ -2655,7 +2661,22 @@ object Infer { else Some(bs :+ item) } - run(groups) + val checkExternals = + GetEnv.flatMap { env => + externals + .toList + .sortBy { case (_, (_, region)) => region } + .parTraverse_ { case (_, (t, region)) => + env.getKind(t, region) match { + case Right(Kind.Type) => unit + case Right(cons @ Kind.Cons(_, _)) => + fail(Error.KindExpectedType(t, cons, region)) + case Left(err) => fail(err) + } + } + } + + run(groups).parProductL(checkExternals) } /** This is useful to testing purposes. diff --git a/core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala b/core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala index 3eea45f7..626a93ec 100644 --- a/core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala @@ -3990,4 +3990,21 @@ test = TestSuite("bases", 12 ) } + + test("external defs with explicit type parameters exactly match") { + val testCode = """ +package ErrorCheck + +external def foo[b](lst: List[a]) -> a + +""" + evalFail(List(testCode)) { + case kie @ PackageError.SourceConverterErrorsIn(_, _, _) => + val message = kie.message(Map.empty, Colorize.None) + assert(message.contains("Region(30,59)")) + assert(message.contains("[b], not the same as [a]")) + assert(testCode.substring(30, 59) == "def foo[b](lst: List[a]) -> a") + () + } + } } diff --git a/core/src/test/scala/org/bykn/bosatsu/Gen.scala b/core/src/test/scala/org/bykn/bosatsu/Gen.scala index 363c4f32..64b7a588 100644 --- a/core/src/test/scala/org/bykn/bosatsu/Gen.scala +++ b/core/src/test/scala/org/bykn/bosatsu/Gen.scala @@ -1120,11 +1120,13 @@ object Generators { val genExternalDef: Gen[Statement] = for { name <- bindIdentGen + tas0 <- Gen.option(smallList(Gen.zip(typeRefVarGen, Gen.option(NTypeGen.genKind)))) argc <- Gen.choose(0, 5) argG = Gen.zip(bindIdentGen, typeRefGen) args <- Gen.listOfN(argc, argG) + tas = if (args.isEmpty) None else tas0 res <- typeRefGen - } yield Statement.ExternalDef(name, args, res)(emptyRegion) + } yield Statement.ExternalDef(name, tas.flatMap(NonEmptyList.fromList(_)), args, res)(emptyRegion) val genEnum: Gen[Statement] = for { diff --git a/core/src/test/scala/org/bykn/bosatsu/ParserTest.scala b/core/src/test/scala/org/bykn/bosatsu/ParserTest.scala index 4f050761..7d1c0db2 100644 --- a/core/src/test/scala/org/bykn/bosatsu/ParserTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/ParserTest.scala @@ -1716,6 +1716,13 @@ external def foo(i: Integer, b: a) -> String external def foo2(i: Integer, b: a) -> String """ ) + roundTrip( + Statement.parser, + """# header +external def foo[a](i: Integer, b: a) -> String +external def foo_co[a: +* -> *](i: Integer, b: a) -> String +""") + } test("we can parse any package") { diff --git a/core/src/test/scala/org/bykn/bosatsu/rankn/RankNInferTest.scala b/core/src/test/scala/org/bykn/bosatsu/rankn/RankNInferTest.scala index 1fd7eb8b..6a8da993 100644 --- a/core/src/test/scala/org/bykn/bosatsu/rankn/RankNInferTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/rankn/RankNInferTest.scala @@ -126,7 +126,8 @@ class RankNInferTest extends AnyFunSuite { testPackage, terms.map { case (k, v, _) => (Identifier.Name(k), RecursionKind.NonRecursive, v) - } + }, + Map.empty ) .runFully(withBools, boolTypes, Type.builtInKinds) match { case Left(err) => assert(false, err) @@ -224,7 +225,7 @@ class RankNInferTest extends AnyFunSuite { fail( "expected an invalid program, but got:\n\n" + program.lets .map { case (b, r, t) => - s"$b: $r = ${t.repr}" + s"$b: $r = ${t.repr.render(80)}" } .mkString("\n\n") ) @@ -1897,4 +1898,32 @@ ignore: exists a. a = Tup(refl_bottom, refl_bottom1, refl_Foo, refl_any) "exists a. a" ) } + + test("test external def with kinds") { + parseProgram(""" +struct Foo +external def foo[f: * -> *](f: f[Foo]) -> Foo + +struct Box[a](item: a) + +f = foo(Box(Foo)) + """, "Foo") + } + + test("ill kinded external defs are not allowed") { + parseProgramIllTyped("""# +struct Foo +external def foo[f: * -> *](function: f) -> f + +f = Foo +""") + + parseProgramIllTyped("""# +struct Box[a](item: a) +external foo: Box + +struct Foo +f = Foo +""") + } }