Skip to content

Commit

Permalink
Improve external def parsing (#1188)
Browse files Browse the repository at this point in the history
  • Loading branch information
johnynek committed Mar 28, 2024
1 parent a472c8d commit 18efb84
Show file tree
Hide file tree
Showing 11 changed files with 186 additions and 34 deletions.
Expand Up @@ -109,7 +109,7 @@ object DefRecursionCheck {
case Def(defn) =>
// make this the same shape as a in declaration
checkDef(TopLevel, defn.copy(result = (defn.result, ())))
case ExternalDef(_, _, _) =>
case ExternalDef(_, _, _, _) =>
unitValid
}
case _ => unitValid
Expand Down
16 changes: 15 additions & 1 deletion core/src/main/scala/org/bykn/bosatsu/Package.scala
Expand Up @@ -273,6 +273,11 @@ object Package {
Type.Const.Defined(p, TypeName(tds.name)) -> tds.region
}.toMap

lazy val extDefRegions: Map[Identifier.Bindable, Region] =
stmts.iterator.collect { case ed: Statement.ExternalDef =>
ed.name -> ed.region
}.toMap

optProg.flatMap {
case Program((importedTypeEnv, parsedTypeEnv), lets, extDefs, _) =>
val inferVarianceParsed
Expand Down Expand Up @@ -336,8 +341,17 @@ object Package {
errs.map(PackageError.TotalityCheckError(p, _))
}

val theseExternals =
parsedTypeEnv
.externalDefs
.collect { case (pack, b, t) if pack === p =>
// by construction this has to have all the regions
(b, (t, extDefRegions(b)))
}
.toMap

val inferenceEither = Infer
.typeCheckLets(p, lets)
.typeCheckLets(p, lets, theseExternals)
.runFully(
withFQN,
Referant.typeConstructors(imps) ++ typeEnv.typeConstructors,
Expand Down
14 changes: 14 additions & 0 deletions core/src/main/scala/org/bykn/bosatsu/PackageError.scala
Expand Up @@ -363,6 +363,20 @@ object PackageError {
) + Doc.hardLine +
context

(doc, Some(region))
case Infer.Error.KindExpectedType(tpe, kind, region) =>
val tmap = showTypes(pack, tpe :: Nil)
val context =
lm.showRegion(region, 2, errColor)
.getOrElse(
Doc.str(region)
) // we should highlight the whole region
val doc = Doc.text("expected type ") +
tmap(tpe) + Doc.text(
" to have kind *, which is to say be a valid value, but it is kind "
) + Kind.toDoc(kind) + Doc.hardLine +
context

(doc, Some(region))
case Infer.Error.KindInvalidApply(applied, leftK, rightK, region) =>
val leftT = applied.on
Expand Down
63 changes: 49 additions & 14 deletions core/src/main/scala/org/bykn/bosatsu/SourceConverter.scala
Expand Up @@ -157,7 +157,7 @@ final class SourceConverter(
SourceConverter.InvalidDefTypeParameters(
args,
freeVarsList,
ds,
Right(ds),
region
),
gen
Expand Down Expand Up @@ -1227,7 +1227,7 @@ final class SourceConverter(
values: List[Statement.ValueStatement]
): Result[Unit] = {
val extDefNames =
values.collect { case ed @ Statement.ExternalDef(name, _, _) =>
values.collect { case ed @ Statement.ExternalDef(name, _, _, _) =>
(name, ed.region)
}

Expand Down Expand Up @@ -1259,7 +1259,7 @@ final class SourceConverter(
s match {
case b @ Statement.Bind(_) => Some(Left(b))
case d @ Statement.Def(_) => Some(Right(d))
case Statement.ExternalDef(_, _, _) => None
case Statement.ExternalDef(_, _, _, _) => None
}

def checkDefBind(s: Statement.ValueStatement): Result[Unit] =
Expand Down Expand Up @@ -1391,7 +1391,7 @@ final class SourceConverter(
stmts.toList.flatMap {
case d @ Def(_) =>
(d.defstatement.name, RecursionKind.Recursive, Left(d)) :: Nil
case ExternalDef(_, _, _) =>
case ExternalDef(_, _, _, _) =>
// we don't allow external defs to shadow at all, so skip it here
Nil
case Bind(BindingStatement(bound, decl, _)) =>
Expand Down Expand Up @@ -1456,7 +1456,7 @@ final class SourceConverter(
}

val withEx: List[Either[ExternalDef, Flattened]] =
stmts.collect { case e @ ExternalDef(_, _, _) => Left(e) }.toList :::
stmts.collect { case e @ ExternalDef(_, _, _, _) => Left(e) }.toList :::
flatIn.map {
case (b, _, Left(d @ Def(dstmt))) =>
Right(Left(Def(dstmt.copy(name = b))(d.region)))
Expand Down Expand Up @@ -1513,7 +1513,7 @@ final class SourceConverter(
(boundName, rec, l1) :: Nil
}
(topBound1, r)
case Left(ExternalDef(n, _, _)) =>
case Left(ExternalDef(n, _, _, _)) =>
(topBound + n, success(Nil))
}
}(SourceConverter.parallelIor)).map(_.flatten)
Expand All @@ -1526,7 +1526,7 @@ final class SourceConverter(
], List[Statement]]] = {
val stmts = Statement.valuesOf(ss).toList
stmts
.collect { case ed @ Statement.ExternalDef(name, params, result) =>
.collect { case ed @ Statement.ExternalDef(name, ta, params, result) =>
(
params.traverse(p => toType(p._2, ed.region)),
toType(result, ed.region)
Expand All @@ -1547,7 +1547,7 @@ final class SourceConverter(
}
}
}
.map { (tpe: rankn.Type) =>
.flatMap { (tpe: rankn.Type) =>
val freeVars = rankn.Type.freeTyVars(tpe :: Nil)
// these vars were parsed so they are never skolem vars
val freeBound = freeVars.map {
Expand All @@ -1557,10 +1557,34 @@ final class SourceConverter(
sys.error(s"invariant violation: parsed a skolem var: $s")
// $COVERAGE-ON$
}
// TODO: Kind support parsing kinds
val maybeForAll =
rankn.Type.forAll(freeBound.map(n => (n, Kind.Type)), tpe)
(name, maybeForAll)
val finalTpe = ta match {
case None =>
success(rankn.Type.forAll(freeBound.map(n => (n, Kind.Type)), tpe))
case Some(frees0) =>
val frees = frees0.map { case (ref, optK) => ref.toBoundVar -> optK }
if (frees.iterator.map(_._1).toSet === freeBound.toSet[rankn.Type.Var.Bound]) {
success(rankn.Type.forAll(frees.map {
case (v, None) => (v, Kind.Type)
case (v, Some(k)) => (v, k)
}, tpe))
}
else {
val kindMap = frees.iterator.collect { case (v, Some(k)) => (v, k) }.toMap
val vs = freeBound.map { v => (v, kindMap.getOrElse(v, Kind.Type)) }
val t = rankn.Type.forAll(vs, tpe)
SourceConverter.partial(
SourceConverter.InvalidDefTypeParameters(
frees0,
freeBound,
Left(ed),
ed.region
),
t
)
}
}

finalTpe.map(name -> _)
}
}
// TODO: we could implement Iterable[Ior[A, B]] => Ior[A, Iterble[B]]
Expand Down Expand Up @@ -1887,10 +1911,21 @@ object SourceConverter {
final case class InvalidDefTypeParameters[B](
declaredParams: NonEmptyList[(TypeRef.TypeVar, Option[Kind])],
free: List[Type.Var.Bound],
defstmt: DefStatement[Pattern.Parsed, B],
defstmt: Either[Statement.ExternalDef, DefStatement[Pattern.Parsed, B]],
region: Region
) extends Error {

def name: Identifier.Bindable = defstmt match {
case Right(ds) => ds.name
case Left(ed) => ed.name
}

def expectation: String = defstmt match {
case Right(_) => "a subset of"
case Left(_) => "the same as"
}


def message = {
def tstr(l: List[Type.Var.Bound]): String =
l.iterator.map(_.name).mkString("[", ", ", "]")
Expand All @@ -1903,7 +1938,7 @@ object SourceConverter {
.renderTrim(80)

val freeStr = tstr(free)
s"${defstmt.name.asString} found declared types: $decl, not a subset of $freeStr"
s"${name.asString} found declared types: $decl, not $expectation $freeStr"
}
}

Expand Down
37 changes: 25 additions & 12 deletions core/src/main/scala/org/bykn/bosatsu/Statement.scala
Expand Up @@ -34,8 +34,8 @@ sealed abstract class Statement {
Struct(nm, typeArgs, args)(r)
case Enum(nm, typeArgs, parts) =>
Enum(nm, typeArgs, parts)(r)
case ExternalDef(name, args, res) =>
ExternalDef(name, args, res)(r)
case ExternalDef(name, ta, args, res) =>
ExternalDef(name, ta, args, res)(r)
case ExternalStruct(nm, targs) =>
ExternalStruct(nm, targs)(r)
}
Expand Down Expand Up @@ -83,7 +83,7 @@ object Statement {
case Bind(BindingStatement(bound, _, _)) =>
bound.names // TODO Keep identifiers
case Def(defstatement) => defstatement.name :: Nil
case ExternalDef(name, _, _) => name :: Nil
case ExternalDef(name, _, _, _) => name :: Nil
}

/** These are all the free bindable names in the right hand side of this
Expand All @@ -98,7 +98,7 @@ object Statement {
(innerFrees - defstatement.name) -- defstatement.args.toList.flatMap(
_.patternNames
)
case ExternalDef(_, _, _) => SortedSet.empty
case ExternalDef(_, _, _, _) => SortedSet.empty
}

/** These are all the bindings, free or not, in this Statement
Expand All @@ -109,7 +109,7 @@ object Statement {
case Def(defstatement) =>
(defstatement.result.get.allNames + defstatement.name) ++ defstatement.args.toList
.flatMap(_.patternNames)
case ExternalDef(name, _, _) => SortedSet(name)
case ExternalDef(name, _, _, _) => SortedSet(name)
}
}

Expand All @@ -126,6 +126,7 @@ object Statement {
extends ValueStatement
case class ExternalDef(
name: Bindable,
typeArgs: Option[NonEmptyList[(TypeRef.TypeVar, Option[Kind])]],
params: List[(Bindable, TypeRef)],
result: TypeRef
)(val region: Region)
Expand Down Expand Up @@ -230,6 +231,10 @@ object Statement {

val externalDef = {

val kindAnnot: P[Kind] =
(maybeSpace.soft.with1 *> (P.char(':') *> maybeSpace *> Kind.parser))
val typeParams = TypeRef.typeParams(kindAnnot.?).?

val args =
P.char('(') *> maybeSpace *> argParser.nonEmptyList <* maybeSpace <* P
.char(')')
Expand All @@ -239,16 +244,16 @@ object Statement {

(((keySpace(
"def"
) *> Identifier.bindableParser ~ args ~ result).region) <* toEOL)
.map { case (region, ((name, args), resType)) =>
ExternalDef(name, args.toList, resType)(region)
) *> Identifier.bindableParser ~ typeParams ~ args ~ result).region) <* toEOL)
.map { case (region, (((name, tps), args), resType)) =>
ExternalDef(name, tps, args.toList, resType)(region)
}
}

val externalVal =
(argParser <* toEOL).region
.map { case (region, (name, resType)) =>
ExternalDef(name, Nil, resType)(region)
ExternalDef(name, None, Nil, resType)(region)
}

keySpace("external") *> P.oneOf(
Expand Down Expand Up @@ -385,11 +390,19 @@ object Statement {
.char(':') +
colonSep +
indentedCons + Doc.line
case ExternalDef(name, Nil, res) =>
case ExternalDef(name, None, Nil, res) =>
Doc.text("external ") + Document[Bindable].document(name) + Doc.text(
": "
) + res.toDoc + Doc.line
case ExternalDef(name, args, res) =>
case ExternalDef(name, tps, args, res) =>
val taDoc = tps match {
case None => Doc.empty
case Some(ta) =>
TypeRef.docTypeArgs(ta.toList) {
case None => Doc.empty
case Some(k) => colonSpace + Kind.toDoc(k)
}
}
val argDoc = {
val da = Doc.intercalate(
Doc.text(", "),
Expand All @@ -401,7 +414,7 @@ object Statement {
}
Doc.text("external def ") + Document[Bindable].document(
name
) + argDoc + Doc.text(" -> ") + res.toDoc + Doc.line
) + taDoc + argDoc + Doc.text(" -> ") + res.toDoc + Doc.line
case ExternalStruct(nm, typeArgs) =>
val taDoc =
TypeRef.docTypeArgs(typeArgs.toList) {
Expand Down
Expand Up @@ -21,7 +21,7 @@ object TypeRefConverter {
import TypeRef._

t match {
case TypeVar(v) => Applicative[F].pure(TyVar(Type.Var.Bound(v)))
case tv @ TypeVar(_) => Applicative[F].pure(TyVar(tv.toBoundVar))
case TypeName(n) => nameToType(n.ident).map(TyConst(_))
case TypeArrow(as, b) =>
(as.traverse(toType(_)), toType(b)).mapN(Fun(_, _))
Expand Down
25 changes: 23 additions & 2 deletions core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala
Expand Up @@ -182,6 +182,11 @@ object Infer {
rightK: Kind,
region: Region
) extends TypeError
case class KindExpectedType(
tpe: Type,
kind: Kind.Cons,
region: Region
) extends TypeError
case class KindMismatch(
target: Type,
targetKind: Kind,
Expand Down Expand Up @@ -2614,7 +2619,8 @@ object Infer {
*/
def typeCheckLets[A: HasRegion](
pack: PackageName,
ls: List[(Bindable, RecursionKind, Expr[A])]
ls: List[(Bindable, RecursionKind, Expr[A])],
externals: Map[Bindable, (Type, Region)]
): Infer[List[(Bindable, RecursionKind, TypedExpr[A])]] = {
// Group together lets that don't include each other to get more type errors
// if we can
Expand Down Expand Up @@ -2655,7 +2661,22 @@ object Infer {
else Some(bs :+ item)
}

run(groups)
val checkExternals =
GetEnv.flatMap { env =>
externals
.toList
.sortBy { case (_, (_, region)) => region }
.parTraverse_ { case (_, (t, region)) =>
env.getKind(t, region) match {
case Right(Kind.Type) => unit
case Right(cons @ Kind.Cons(_, _)) =>
fail(Error.KindExpectedType(t, cons, region))
case Left(err) => fail(err)
}
}
}

run(groups).parProductL(checkExternals)
}

/** This is useful to testing purposes.
Expand Down
17 changes: 17 additions & 0 deletions core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala
Expand Up @@ -3990,4 +3990,21 @@ test = TestSuite("bases",
12
)
}

test("external defs with explicit type parameters exactly match") {
val testCode = """
package ErrorCheck
external def foo[b](lst: List[a]) -> a
"""
evalFail(List(testCode)) {
case kie @ PackageError.SourceConverterErrorsIn(_, _, _) =>
val message = kie.message(Map.empty, Colorize.None)
assert(message.contains("Region(30,59)"))
assert(message.contains("[b], not the same as [a]"))
assert(testCode.substring(30, 59) == "def foo[b](lst: List[a]) -> a")
()
}
}
}

0 comments on commit 18efb84

Please sign in to comment.