Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug in python if/else chains #1183

Merged
merged 1 commit into from Mar 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
15 changes: 14 additions & 1 deletion core/src/main/scala/org/bykn/bosatsu/Matchless.scala
Expand Up @@ -173,7 +173,20 @@ object Matchless {
l.nonEmpty || hasSideEffect(b)
}

case class If(cond: BoolExpr, thenExpr: Expr, elseExpr: Expr) extends Expr
case class If(cond: BoolExpr, thenExpr: Expr, elseExpr: Expr) extends Expr {
def flatten: (NonEmptyList[(BoolExpr, Expr)], Expr) = {
def combine(expr: Expr): (List[(BoolExpr, Expr)], Expr) =
expr match {
case If(c1, t1, e1) =>
val (ifs, e2) = combine(e1)
(((c1, t1)) :: ifs, e2)
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

previously this code was in PythonGen, but it really only depends on Matchless.Expr. The order was being reversed by appending to the tail: ifs :+ ((c1, t1)) which for some branches changes semantics (where Matchless has leveraged prior information to do less work).

case last => (Nil, last)
}

val (rest, last) = combine(elseExpr)
(NonEmptyList((cond, thenExpr), rest), last)
}
}
case class Always(cond: BoolExpr, thenExpr: Expr) extends Expr
def always(cond: BoolExpr, thenExpr: Expr): Expr =
if (hasSideEffect(cond)) Always(cond, thenExpr)
Expand Down
Expand Up @@ -1878,17 +1878,8 @@ object PythonGen {
// there is no need to
loop(in, slotName)
case Literal(lit) => Env.pure(Code.litToExpr(lit))
case If(cond, thenExpr, elseExpr) =>
def combine(expr: Expr): (List[(BoolExpr, Expr)], Expr) =
expr match {
case If(c1, t1, e1) =>
val (ifs, e2) = combine(e1)
(ifs :+ ((c1, t1)), e2)
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note the order is wrong here: we are pushing c1, t1 on to the end of ifs but that came from expanding e1 so it should come before e1 since it was before originally.

case last => (Nil, last)
}

val (rest, last) = combine(elseExpr)
val ifs = NonEmptyList((cond, thenExpr), rest)
case ifExpr @ If(_, _, _) =>
val (ifs, last) = ifExpr.flatten

val ifsV = ifs.traverse { case (c, t) =>
(boolExpr(c, slotName), loop(t, slotName)).tupled
Expand Down
24 changes: 23 additions & 1 deletion core/src/test/scala/org/bykn/bosatsu/MatchlessTests.scala
Expand Up @@ -16,7 +16,7 @@ import org.scalatest.funsuite.AnyFunSuite
class MatchlessTest extends AnyFunSuite {
implicit val generatorDrivenConfig: PropertyCheckConfiguration =
PropertyCheckConfiguration(minSuccessful =
if (Platform.isScalaJvm) 1000 else 20
if (Platform.isScalaJvm) 5000 else 20
)

type Fn = (PackageName, Constructor) => Option[DataRepr]
Expand Down Expand Up @@ -59,6 +59,11 @@ class MatchlessTest extends AnyFunSuite {
}
}

val genMatchlessExpr: Gen[Matchless.Expr] =
genInputs.map { case (b, r, t, fn) =>
Matchless.fromLet(b, r, t)(fn)
}

test("regressions") {
// this is illegal code, but it shouldn't throw a match error:
val name = Identifier.Name("foo")
Expand Down Expand Up @@ -151,4 +156,21 @@ class MatchlessTest extends AnyFunSuite {
assert(matchlessRes == matchRes)
}
}

test("If.flatten can be unflattened") {
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this test didn't always find the bug (since randomly finding a case is somewhat rare, but after 3 runs, one found it, so I think it would be flakey enough that we would discover the issue).

forAll(genMatchlessExpr) {
case ifexpr @ Matchless.If(_, _, _) =>
val (chain, rest) = ifexpr.flatten
def unflatten(ifs: NonEmptyList[(Matchless.BoolExpr, Matchless.Expr)], elseX: Matchless.Expr): Matchless.If =
ifs.tail match {
case Nil => Matchless.If(ifs.head._1, ifs.head._2, elseX)
case head :: next =>
val end = unflatten(NonEmptyList(head, next), elseX)
Matchless.If(ifs.head._1, ifs.head._2, end)
}

assert(unflatten(chain, rest) == ifexpr)
case _ => ()
}
}
}
8 changes: 4 additions & 4 deletions test_workspace/BinNat.bosatsu
Expand Up @@ -45,10 +45,6 @@ def toBinNat(n: Int) -> BinNat:

def cmp_BinNat(a: BinNat, b: BinNat) -> Comparison:
recur a:
case Zero:
match b:
case Odd(_) | Even(_): LT
case Zero: EQ
case Odd(a1):
match b:
case Odd(b1): cmp_BinNat(a1, b1)
Expand All @@ -71,6 +67,10 @@ def cmp_BinNat(a: BinNat, b: BinNat) -> Comparison:
case GT | EQ: GT
case LT: LT
case Zero: GT
case Zero:
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is more efficient since Zero is rare, but triggered the bug before this fix.

match b:
case Odd(_) | Even(_): LT
case Zero: EQ

# this is more efficient potentially than cmp_BinNat
# because at the first difference we can stop. In the worst
Expand Down