Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve external def parsing #1188

Merged
merged 6 commits into from Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -109,7 +109,7 @@ object DefRecursionCheck {
case Def(defn) =>
// make this the same shape as a in declaration
checkDef(TopLevel, defn.copy(result = (defn.result, ())))
case ExternalDef(_, _, _) =>
case ExternalDef(_, _, _, _) =>
unitValid
}
case _ => unitValid
Expand Down
16 changes: 15 additions & 1 deletion core/src/main/scala/org/bykn/bosatsu/Package.scala
Expand Up @@ -273,6 +273,11 @@ object Package {
Type.Const.Defined(p, TypeName(tds.name)) -> tds.region
}.toMap

lazy val extDefRegions: Map[Identifier.Bindable, Region] =
stmts.iterator.collect { case ed: Statement.ExternalDef =>
ed.name -> ed.region
}.toMap

optProg.flatMap {
case Program((importedTypeEnv, parsedTypeEnv), lets, extDefs, _) =>
val inferVarianceParsed
Expand Down Expand Up @@ -336,8 +341,17 @@ object Package {
errs.map(PackageError.TotalityCheckError(p, _))
}

val theseExternals =
parsedTypeEnv
.externalDefs
.collect { case (pack, b, t) if pack === p =>
// by construction this has to have all the regions
(b, (t, extDefRegions(b)))
}
.toMap

val inferenceEither = Infer
.typeCheckLets(p, lets)
.typeCheckLets(p, lets, theseExternals)
.runFully(
withFQN,
Referant.typeConstructors(imps) ++ typeEnv.typeConstructors,
Expand Down
14 changes: 14 additions & 0 deletions core/src/main/scala/org/bykn/bosatsu/PackageError.scala
Expand Up @@ -363,6 +363,20 @@ object PackageError {
) + Doc.hardLine +
context

(doc, Some(region))
case Infer.Error.KindExpectedType(tpe, kind, region) =>
val tmap = showTypes(pack, tpe :: Nil)
val context =
lm.showRegion(region, 2, errColor)
.getOrElse(
Doc.str(region)
) // we should highlight the whole region
val doc = Doc.text("expected type ") +
tmap(tpe) + Doc.text(
" to have kind *, which is to say be a valid value, but it is kind "
) + Kind.toDoc(kind) + Doc.hardLine +
context

(doc, Some(region))
case Infer.Error.KindInvalidApply(applied, leftK, rightK, region) =>
val leftT = applied.on
Expand Down
63 changes: 49 additions & 14 deletions core/src/main/scala/org/bykn/bosatsu/SourceConverter.scala
Expand Up @@ -157,7 +157,7 @@ final class SourceConverter(
SourceConverter.InvalidDefTypeParameters(
args,
freeVarsList,
ds,
Right(ds),
region
),
gen
Expand Down Expand Up @@ -1227,7 +1227,7 @@ final class SourceConverter(
values: List[Statement.ValueStatement]
): Result[Unit] = {
val extDefNames =
values.collect { case ed @ Statement.ExternalDef(name, _, _) =>
values.collect { case ed @ Statement.ExternalDef(name, _, _, _) =>
(name, ed.region)
}

Expand Down Expand Up @@ -1259,7 +1259,7 @@ final class SourceConverter(
s match {
case b @ Statement.Bind(_) => Some(Left(b))
case d @ Statement.Def(_) => Some(Right(d))
case Statement.ExternalDef(_, _, _) => None
case Statement.ExternalDef(_, _, _, _) => None
}

def checkDefBind(s: Statement.ValueStatement): Result[Unit] =
Expand Down Expand Up @@ -1391,7 +1391,7 @@ final class SourceConverter(
stmts.toList.flatMap {
case d @ Def(_) =>
(d.defstatement.name, RecursionKind.Recursive, Left(d)) :: Nil
case ExternalDef(_, _, _) =>
case ExternalDef(_, _, _, _) =>
// we don't allow external defs to shadow at all, so skip it here
Nil
case Bind(BindingStatement(bound, decl, _)) =>
Expand Down Expand Up @@ -1456,7 +1456,7 @@ final class SourceConverter(
}

val withEx: List[Either[ExternalDef, Flattened]] =
stmts.collect { case e @ ExternalDef(_, _, _) => Left(e) }.toList :::
stmts.collect { case e @ ExternalDef(_, _, _, _) => Left(e) }.toList :::
flatIn.map {
case (b, _, Left(d @ Def(dstmt))) =>
Right(Left(Def(dstmt.copy(name = b))(d.region)))
Expand Down Expand Up @@ -1513,7 +1513,7 @@ final class SourceConverter(
(boundName, rec, l1) :: Nil
}
(topBound1, r)
case Left(ExternalDef(n, _, _)) =>
case Left(ExternalDef(n, _, _, _)) =>
(topBound + n, success(Nil))
}
}(SourceConverter.parallelIor)).map(_.flatten)
Expand All @@ -1526,7 +1526,7 @@ final class SourceConverter(
], List[Statement]]] = {
val stmts = Statement.valuesOf(ss).toList
stmts
.collect { case ed @ Statement.ExternalDef(name, params, result) =>
.collect { case ed @ Statement.ExternalDef(name, ta, params, result) =>
(
params.traverse(p => toType(p._2, ed.region)),
toType(result, ed.region)
Expand All @@ -1547,7 +1547,7 @@ final class SourceConverter(
}
}
}
.map { (tpe: rankn.Type) =>
.flatMap { (tpe: rankn.Type) =>
val freeVars = rankn.Type.freeTyVars(tpe :: Nil)
// these vars were parsed so they are never skolem vars
val freeBound = freeVars.map {
Expand All @@ -1557,10 +1557,34 @@ final class SourceConverter(
sys.error(s"invariant violation: parsed a skolem var: $s")
// $COVERAGE-ON$
}
// TODO: Kind support parsing kinds
val maybeForAll =
rankn.Type.forAll(freeBound.map(n => (n, Kind.Type)), tpe)
(name, maybeForAll)
val finalTpe = ta match {
case None =>
success(rankn.Type.forAll(freeBound.map(n => (n, Kind.Type)), tpe))
case Some(frees0) =>
val frees = frees0.map { case (ref, optK) => ref.toBoundVar -> optK }
if (frees.iterator.map(_._1).toSet === freeBound.toSet[rankn.Type.Var.Bound]) {
success(rankn.Type.forAll(frees.map {
case (v, None) => (v, Kind.Type)
case (v, Some(k)) => (v, k)
}, tpe))
}
else {
val kindMap = frees.iterator.collect { case (v, Some(k)) => (v, k) }.toMap
val vs = freeBound.map { v => (v, kindMap.getOrElse(v, Kind.Type)) }
val t = rankn.Type.forAll(vs, tpe)
SourceConverter.partial(
SourceConverter.InvalidDefTypeParameters(
frees0,
freeBound,
Left(ed),
ed.region
),
t
)
}
}

finalTpe.map(name -> _)
}
}
// TODO: we could implement Iterable[Ior[A, B]] => Ior[A, Iterble[B]]
Expand Down Expand Up @@ -1887,10 +1911,21 @@ object SourceConverter {
final case class InvalidDefTypeParameters[B](
declaredParams: NonEmptyList[(TypeRef.TypeVar, Option[Kind])],
free: List[Type.Var.Bound],
defstmt: DefStatement[Pattern.Parsed, B],
defstmt: Either[Statement.ExternalDef, DefStatement[Pattern.Parsed, B]],
region: Region
) extends Error {

def name: Identifier.Bindable = defstmt match {
case Right(ds) => ds.name
case Left(ed) => ed.name
}

def expectation: String = defstmt match {
case Right(_) => "a subset of"
case Left(_) => "the same as"
}


def message = {
def tstr(l: List[Type.Var.Bound]): String =
l.iterator.map(_.name).mkString("[", ", ", "]")
Expand All @@ -1903,7 +1938,7 @@ object SourceConverter {
.renderTrim(80)

val freeStr = tstr(free)
s"${defstmt.name.asString} found declared types: $decl, not a subset of $freeStr"
s"${name.asString} found declared types: $decl, not $expectation $freeStr"
}
}

Expand Down
37 changes: 25 additions & 12 deletions core/src/main/scala/org/bykn/bosatsu/Statement.scala
Expand Up @@ -34,8 +34,8 @@ sealed abstract class Statement {
Struct(nm, typeArgs, args)(r)
case Enum(nm, typeArgs, parts) =>
Enum(nm, typeArgs, parts)(r)
case ExternalDef(name, args, res) =>
ExternalDef(name, args, res)(r)
case ExternalDef(name, ta, args, res) =>
ExternalDef(name, ta, args, res)(r)
case ExternalStruct(nm, targs) =>
ExternalStruct(nm, targs)(r)
}
Expand Down Expand Up @@ -83,7 +83,7 @@ object Statement {
case Bind(BindingStatement(bound, _, _)) =>
bound.names // TODO Keep identifiers
case Def(defstatement) => defstatement.name :: Nil
case ExternalDef(name, _, _) => name :: Nil
case ExternalDef(name, _, _, _) => name :: Nil
}

/** These are all the free bindable names in the right hand side of this
Expand All @@ -98,7 +98,7 @@ object Statement {
(innerFrees - defstatement.name) -- defstatement.args.toList.flatMap(
_.patternNames
)
case ExternalDef(_, _, _) => SortedSet.empty
case ExternalDef(_, _, _, _) => SortedSet.empty
}

/** These are all the bindings, free or not, in this Statement
Expand All @@ -109,7 +109,7 @@ object Statement {
case Def(defstatement) =>
(defstatement.result.get.allNames + defstatement.name) ++ defstatement.args.toList
.flatMap(_.patternNames)
case ExternalDef(name, _, _) => SortedSet(name)
case ExternalDef(name, _, _, _) => SortedSet(name)
}
}

Expand All @@ -126,6 +126,7 @@ object Statement {
extends ValueStatement
case class ExternalDef(
name: Bindable,
typeArgs: Option[NonEmptyList[(TypeRef.TypeVar, Option[Kind])]],
params: List[(Bindable, TypeRef)],
result: TypeRef
)(val region: Region)
Expand Down Expand Up @@ -230,6 +231,10 @@ object Statement {

val externalDef = {

val kindAnnot: P[Kind] =
(maybeSpace.soft.with1 *> (P.char(':') *> maybeSpace *> Kind.parser))
val typeParams = TypeRef.typeParams(kindAnnot.?).?

val args =
P.char('(') *> maybeSpace *> argParser.nonEmptyList <* maybeSpace <* P
.char(')')
Expand All @@ -239,16 +244,16 @@ object Statement {

(((keySpace(
"def"
) *> Identifier.bindableParser ~ args ~ result).region) <* toEOL)
.map { case (region, ((name, args), resType)) =>
ExternalDef(name, args.toList, resType)(region)
) *> Identifier.bindableParser ~ typeParams ~ args ~ result).region) <* toEOL)
.map { case (region, (((name, tps), args), resType)) =>
ExternalDef(name, tps, args.toList, resType)(region)
}
}

val externalVal =
(argParser <* toEOL).region
.map { case (region, (name, resType)) =>
ExternalDef(name, Nil, resType)(region)
ExternalDef(name, None, Nil, resType)(region)
}

keySpace("external") *> P.oneOf(
Expand Down Expand Up @@ -385,11 +390,19 @@ object Statement {
.char(':') +
colonSep +
indentedCons + Doc.line
case ExternalDef(name, Nil, res) =>
case ExternalDef(name, None, Nil, res) =>
Doc.text("external ") + Document[Bindable].document(name) + Doc.text(
": "
) + res.toDoc + Doc.line
case ExternalDef(name, args, res) =>
case ExternalDef(name, tps, args, res) =>
val taDoc = tps match {
case None => Doc.empty
case Some(ta) =>
TypeRef.docTypeArgs(ta.toList) {
case None => Doc.empty
case Some(k) => colonSpace + Kind.toDoc(k)
}
}
val argDoc = {
val da = Doc.intercalate(
Doc.text(", "),
Expand All @@ -401,7 +414,7 @@ object Statement {
}
Doc.text("external def ") + Document[Bindable].document(
name
) + argDoc + Doc.text(" -> ") + res.toDoc + Doc.line
) + taDoc + argDoc + Doc.text(" -> ") + res.toDoc + Doc.line
case ExternalStruct(nm, typeArgs) =>
val taDoc =
TypeRef.docTypeArgs(typeArgs.toList) {
Expand Down
Expand Up @@ -21,7 +21,7 @@ object TypeRefConverter {
import TypeRef._

t match {
case TypeVar(v) => Applicative[F].pure(TyVar(Type.Var.Bound(v)))
case tv @ TypeVar(_) => Applicative[F].pure(TyVar(tv.toBoundVar))
case TypeName(n) => nameToType(n.ident).map(TyConst(_))
case TypeArrow(as, b) =>
(as.traverse(toType(_)), toType(b)).mapN(Fun(_, _))
Expand Down
25 changes: 23 additions & 2 deletions core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala
Expand Up @@ -182,6 +182,11 @@ object Infer {
rightK: Kind,
region: Region
) extends TypeError
case class KindExpectedType(
tpe: Type,
kind: Kind.Cons,
region: Region
) extends TypeError
case class KindMismatch(
target: Type,
targetKind: Kind,
Expand Down Expand Up @@ -2614,7 +2619,8 @@ object Infer {
*/
def typeCheckLets[A: HasRegion](
pack: PackageName,
ls: List[(Bindable, RecursionKind, Expr[A])]
ls: List[(Bindable, RecursionKind, Expr[A])],
externals: Map[Bindable, (Type, Region)]
): Infer[List[(Bindable, RecursionKind, TypedExpr[A])]] = {
// Group together lets that don't include each other to get more type errors
// if we can
Expand Down Expand Up @@ -2655,7 +2661,22 @@ object Infer {
else Some(bs :+ item)
}

run(groups)
val checkExternals =
GetEnv.flatMap { env =>
externals
.toList
.sortBy { case (_, (_, region)) => region }
.parTraverse_ { case (_, (t, region)) =>
env.getKind(t, region) match {
case Right(Kind.Type) => unit
case Right(cons @ Kind.Cons(_, _)) =>
fail(Error.KindExpectedType(t, cons, region))
case Left(err) => fail(err)
}
}
}

run(groups).parProductL(checkExternals)
}

/** This is useful to testing purposes.
Expand Down
17 changes: 17 additions & 0 deletions core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala
Expand Up @@ -3990,4 +3990,21 @@ test = TestSuite("bases",
12
)
}

test("external defs with explicit type parameters exactly match") {
val testCode = """
package ErrorCheck

external def foo[b](lst: List[a]) -> a

"""
evalFail(List(testCode)) {
case kie @ PackageError.SourceConverterErrorsIn(_, _, _) =>
val message = kie.message(Map.empty, Colorize.None)
assert(message.contains("Region(30,59)"))
assert(message.contains("[b], not the same as [a]"))
assert(testCode.substring(30, 59) == "def foo[b](lst: List[a]) -> a")
()
}
}
}