From 0c0d723b99bb22662b4cdd1e32e622d2862836f1 Mon Sep 17 00:00:00 2001 From: Patrick Oscar Boykin Date: Sat, 23 Mar 2024 08:43:02 -1000 Subject: [PATCH] Fix bug in python if/else chains --- .../scala/org/bykn/bosatsu/Matchless.scala | 15 +++++++++++- .../bosatsu/codegen/python/PythonGen.scala | 13 ++-------- .../org/bykn/bosatsu/MatchlessTests.scala | 24 ++++++++++++++++++- test_workspace/BinNat.bosatsu | 8 +++---- 4 files changed, 43 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/org/bykn/bosatsu/Matchless.scala b/core/src/main/scala/org/bykn/bosatsu/Matchless.scala index 4ae773ba..167f333c 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Matchless.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Matchless.scala @@ -173,7 +173,20 @@ object Matchless { l.nonEmpty || hasSideEffect(b) } - case class If(cond: BoolExpr, thenExpr: Expr, elseExpr: Expr) extends Expr + case class If(cond: BoolExpr, thenExpr: Expr, elseExpr: Expr) extends Expr { + def flatten: (NonEmptyList[(BoolExpr, Expr)], Expr) = { + def combine(expr: Expr): (List[(BoolExpr, Expr)], Expr) = + expr match { + case If(c1, t1, e1) => + val (ifs, e2) = combine(e1) + (((c1, t1)) :: ifs, e2) + case last => (Nil, last) + } + + val (rest, last) = combine(elseExpr) + (NonEmptyList((cond, thenExpr), rest), last) + } + } case class Always(cond: BoolExpr, thenExpr: Expr) extends Expr def always(cond: BoolExpr, thenExpr: Expr): Expr = if (hasSideEffect(cond)) Always(cond, thenExpr) diff --git a/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonGen.scala b/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonGen.scala index 3290e441..8b10131f 100644 --- a/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonGen.scala +++ b/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonGen.scala @@ -1878,17 +1878,8 @@ object PythonGen { // there is no need to loop(in, slotName) case Literal(lit) => Env.pure(Code.litToExpr(lit)) - case If(cond, thenExpr, elseExpr) => - def combine(expr: Expr): (List[(BoolExpr, Expr)], Expr) = - expr match { - case If(c1, t1, e1) => - val (ifs, e2) = combine(e1) - (ifs :+ ((c1, t1)), e2) - case last => (Nil, last) - } - - val (rest, last) = combine(elseExpr) - val ifs = NonEmptyList((cond, thenExpr), rest) + case ifExpr @ If(_, _, _) => + val (ifs, last) = ifExpr.flatten val ifsV = ifs.traverse { case (c, t) => (boolExpr(c, slotName), loop(t, slotName)).tupled diff --git a/core/src/test/scala/org/bykn/bosatsu/MatchlessTests.scala b/core/src/test/scala/org/bykn/bosatsu/MatchlessTests.scala index 359dc41a..284512b5 100644 --- a/core/src/test/scala/org/bykn/bosatsu/MatchlessTests.scala +++ b/core/src/test/scala/org/bykn/bosatsu/MatchlessTests.scala @@ -16,7 +16,7 @@ import org.scalatest.funsuite.AnyFunSuite class MatchlessTest extends AnyFunSuite { implicit val generatorDrivenConfig: PropertyCheckConfiguration = PropertyCheckConfiguration(minSuccessful = - if (Platform.isScalaJvm) 1000 else 20 + if (Platform.isScalaJvm) 5000 else 20 ) type Fn = (PackageName, Constructor) => Option[DataRepr] @@ -59,6 +59,11 @@ class MatchlessTest extends AnyFunSuite { } } + val genMatchlessExpr: Gen[Matchless.Expr] = + genInputs.map { case (b, r, t, fn) => + Matchless.fromLet(b, r, t)(fn) + } + test("regressions") { // this is illegal code, but it shouldn't throw a match error: val name = Identifier.Name("foo") @@ -151,4 +156,21 @@ class MatchlessTest extends AnyFunSuite { assert(matchlessRes == matchRes) } } + + test("If.flatten can be unflattened") { + forAll(genMatchlessExpr) { + case ifexpr @ Matchless.If(_, _, _) => + val (chain, rest) = ifexpr.flatten + def unflatten(ifs: NonEmptyList[(Matchless.BoolExpr, Matchless.Expr)], elseX: Matchless.Expr): Matchless.If = + ifs.tail match { + case Nil => Matchless.If(ifs.head._1, ifs.head._2, elseX) + case head :: next => + val end = unflatten(NonEmptyList(head, next), elseX) + Matchless.If(ifs.head._1, ifs.head._2, end) + } + + assert(unflatten(chain, rest) == ifexpr) + case _ => () + } + } } diff --git a/test_workspace/BinNat.bosatsu b/test_workspace/BinNat.bosatsu index f1243afa..49d5676f 100644 --- a/test_workspace/BinNat.bosatsu +++ b/test_workspace/BinNat.bosatsu @@ -45,10 +45,6 @@ def toBinNat(n: Int) -> BinNat: def cmp_BinNat(a: BinNat, b: BinNat) -> Comparison: recur a: - case Zero: - match b: - case Odd(_) | Even(_): LT - case Zero: EQ case Odd(a1): match b: case Odd(b1): cmp_BinNat(a1, b1) @@ -71,6 +67,10 @@ def cmp_BinNat(a: BinNat, b: BinNat) -> Comparison: case GT | EQ: GT case LT: LT case Zero: GT + case Zero: + match b: + case Odd(_) | Even(_): LT + case Zero: EQ # this is more efficient potentially than cmp_BinNat # because at the first difference we can stop. In the worst