Skip to content

Commit

Permalink
Add property checks for BinNat (#1176)
Browse files Browse the repository at this point in the history
  • Loading branch information
johnynek committed Mar 22, 2024
1 parent 0b42877 commit 5191197
Show file tree
Hide file tree
Showing 7 changed files with 216 additions and 31 deletions.
2 changes: 0 additions & 2 deletions core/src/test/scala/org/bykn/bosatsu/ShapeTest.scala
Expand Up @@ -3,8 +3,6 @@ package org.bykn.bosatsu
import org.bykn.bosatsu.rankn.TypeEnv
import org.scalatest.funsuite.AnyFunSuite

import cats.syntax.all._

class ShapeTest extends AnyFunSuite {

def makeTE(
Expand Down
Expand Up @@ -212,7 +212,7 @@ class RankNInferTest extends AnyFunSuite {

// this could be used to test the string representation of expressions
def checkTERepr(statement: String, repr: String) =
checkLast(statement)(te => assert(te.repr == repr))
checkLast(statement)(te => assert(te.repr.render(80) == repr))

/** Test that a program is ill-typed
*/
Expand Down
3 changes: 1 addition & 2 deletions test_workspace/AvlTree.bosatsu
Expand Up @@ -210,8 +210,7 @@ contains_test = (
])
)

def eq_i(a, b):
cmp_Int(a, b) matches EQ
eq_i = eq_Int

def add_increases_size(t, i, msg):
s0 = size(t)
Expand Down
148 changes: 132 additions & 16 deletions test_workspace/BinNat.bosatsu
Expand Up @@ -2,7 +2,8 @@ package Bosatsu/BinNat

from Bosatsu/Nat import Nat, Zero as NatZero, Succ as NatSucc, times2 as times2_Nat

export BinNat(), toInt, toNat, toBinNat, next, add_BinNat, times2, div2, prev
export (BinNat(), toInt, toNat, toBinNat, next, add_BinNat, times2, div2,
prev, times_BinNat, exp, cmp_BinNat, is_even, sub_BinNat, sub_Option, eq_BinNat)
# a natural number with three variants:
# Zero = 0
# Odd(n) = 2n + 1
Expand All @@ -11,6 +12,9 @@ export BinNat(), toInt, toNat, toBinNat, next, add_BinNat, times2, div2, prev
# Zero, Odd(Zero), Even(Zero), Odd(Odd(Zero)), Even(Odd(Zero))
enum BinNat: Zero, Odd(half: BinNat), Even(half1: BinNat)

def is_even(b: BinNat) -> Bool:
b matches Zero | Even(_)

# Convert a BinNat into the equivalent Int
# this is O(log(b)) operation
def toInt(b: BinNat) -> Int:
Expand All @@ -37,71 +41,168 @@ def toBinNat(n: Int) -> BinNat:
(dec(n), fns)
)
# Now apply all the transformations
fns.foldLeft(Zero, \n, fn -> fn(n))
fns.foldLeft(Zero, (n, fn) -> fn(n))

def cmp_BinNat(a: BinNat, b: BinNat) -> Comparison:
recur a:
case Zero:
match b:
case Odd(_) | Even(_): LT
case Zero: EQ
case Odd(a1):
match b:
case Odd(b1): cmp_BinNat(a1, b1)
case Even(b1):
# 2n + 1 <> 2m + 2
# if n <= m, LT
# if n > m GT
match cmp_BinNat(a1, b1):
case LT | EQ: LT
case GT: GT
case Zero: GT
case Even(a1):
match b:
case Even(b1): cmp_BinNat(a1, b1)
case Odd(b1):
# 2n + 2 <> 2m + 1
# if n >= m, GT
# if n < m LT
match cmp_BinNat(a1, b1):
case GT | EQ: GT
case LT: LT
case Zero: GT

# this is more efficient potentially than cmp_BinNat
# because at the first difference we can stop. In the worst
# case of equality, the cost is the same.
def eq_BinNat(a: BinNat, b: BinNat) -> Bool:
recur a:
case Zero: b matches Zero
case Odd(n):
match b:
case Odd(m): eq_BinNat(n, m)
case _: False
case Even(n):
match b:
case Even(m): eq_BinNat(n, m)
case _: False

# Return the next number
def next(b: BinNat) -> BinNat:
recur b:
Zero: Odd(Zero)
Odd(half):
# (2n + 1) + 1 = 2(n + 1)
Even(half)
Even(half1):
# 2(n + 1) + 1
Odd(next(half1))
Zero: Odd(Zero)

# Return the previous number if the number is > 0, else return 0
def prev(b: BinNat) -> BinNat:
recur b:
Zero: Zero
Odd(Zero):
# This breaks the law below because 0 - 1 = 0 in this function
Zero
Odd(half):
case Zero | Odd(Zero): Zero
case Odd(half):
# (2n + 1) - 1 = 2n = 2(n-1 + 1)
Even(prev(half))
Even(half1):
case Even(half1):
# 2(n + 1) - 1 = 2n + 1
Odd(half1)

def add_BinNat(left: BinNat, right: BinNat) -> BinNat:
recur left:
Zero: right
Odd(left) as odd:
match right:
Zero: odd
Odd(right):
# 2left + 1 + 2right + 1 = 2((left + right) + 1)
Even(add_BinNat(left, right))
Even(right):
# 2left + 1 + 2(right + 1) = 2((left + right) + 1) + 1
Odd(add_BinNat(left, right.next()))
Zero: odd
Even(left) as even:
match right:
Zero: even
Odd(right):
# 2(left + 1) + 2right + 1 = 2((left + right) + 1) + 1
Odd(add_BinNat(left, right.next()))
Even(right):
# 2(left + 1) + 2(right + 1) = 2((left + right + 1) + 1)
Even(add_BinNat(left, right.next()))
Zero: even
Zero: right

# multiply by 2
def times2(b: BinNat) -> BinNat:
recur b:
Zero: Zero
Odd(n):
#2(2n + 1) = Even(2n)
Even(times2(n))
Even(n):
#2(2(n + 1)) = 2((2n + 1) + 1)
Even(Odd(n))
Zero: Zero

# 2n - 1 if it is defined
def doub_prev(b: BinNat) -> Option[BinNat]:
match b:
case Odd(n):
# 2(2n + 1) - 1 = 4n + 1 = Odd(2n)
Some(Odd(times2(n)))
case Even(n):
# 2(2n + 2) - 1 = 4n + 3 = 2(2n + 1) + 1
Some(Odd(Odd(n)))
case Zero: None

def sub_Option(left: BinNat, right: BinNat) -> Option[BinNat]:
recur left:
case Zero:
match right:
case Zero: Some(Zero)
case _: None
case Odd(left) as odd:
match right:
case Zero: Some(odd)
case Odd(right):
# (2n + 1) - (2m + 1) = 2(n - m)
match sub_Option(left, right):
case Some(n_m): Some(times2(n_m))
case None: None
case Even(right):
# (2n + 1) - (2m + 2) = 2(n - m) - 1
# note if (2n + 1) > (2m + 2), then n > m
match sub_Option(left, right):
case Some(n_m): doub_prev(n_m)
case None: None
case Even(left) as even:
match right:
case Zero: Some(even)
case Odd(right):
# Even can't equal odd, so we never return
# zero. Next an even - odd is odd.
# (2n + 2) - (2m + 1) = 2(n - m) + 1
match sub_Option(left, right):
case Some(n_m): Some(Odd(n_m))
case None: None
case Even(right):
# (2n + 2) - (2m + 2) = 2(n - m)
match sub_Option(left, right):
case Some(n_m): Some(times2(n_m))
case None: None

def sub_BinNat(left: BinNat, right: BinNat) -> BinNat:
match sub_Option(left, right):
case Some(v): v
case None: Zero

def div2(b: BinNat) -> BinNat:
match b:
case Zero: Zero
case Odd(n): n
case Even(n): prev(n)
case Odd(n):
# (2n + 1)/2 = n
n
case Even(n):
# (2n + 2)/2 = n + 1
next(n)

# multiply two BinNat together
def times_BinNat(left: BinNat, right: BinNat) -> BinNat:
Expand All @@ -122,6 +223,21 @@ def times_BinNat(left: BinNat, right: BinNat) -> BinNat:
prod = times_BinNat(left, right)
times2(prod.add_BinNat(right))

one = Odd(Zero)

def exp(base: BinNat, power: BinNat) -> BinNat:
recur power:
case Zero: one
case Odd(n):
# b^(2n + 1) == (b^n) * (b^n) * b
bn = exp(base, n)
bn.times_BinNat(bn).times_BinNat(base)
case Even(n):
# b^(2n + 2) = (b^n * b)^2
bn = exp(base, n)
bn1 = bn.times_BinNat(base)
bn1.times_BinNat(bn1)

# fold(fn, a, Zero) = a
# fold(fn, a, n) = fold(fn, fn(a, n - 1), n - 1)
def fold_left_BinNat(fn: (a, BinNat) -> a, init: a, cnt: BinNat) -> a:
Expand Down Expand Up @@ -158,7 +274,6 @@ def next_law(i, msg):
def times2_law(i, msg):
Assertion(i.toBinNat().times2().toInt().eq_Int(i.times(2)), msg)

one = Odd(Zero)
two = one.next()
three = two.next()
four = three.next()
Expand Down Expand Up @@ -210,4 +325,5 @@ test = TestSuite(
Assertion(fib(two).toInt().eq_Int(2), "fib(2) == 2"),
Assertion(fib(three).toInt().eq_Int(3), "fib(3) == 3"),
Assertion(fib(four).toInt().eq_Int(5), "fib(4) == 5"),
Assertion(cmp_BinNat(54.toBinNat(), 54.toBinNat()) matches EQ, "54 == 54"),
])
2 changes: 1 addition & 1 deletion test_workspace/Nat.bosatsu
Expand Up @@ -122,7 +122,7 @@ n4 = Succ(n3)
n5 = Succ(n4)

def operator ==(i0: Int, i1: Int):
cmp_Int(i0, i1) matches EQ
eq_Int(i0, i1)

def addLaw(n1: Nat, n2: Nat, label: String) -> Test:
Assertion(add(n1, n2).to_Int() == (n1.to_Int() + n2.to_Int()), label)
Expand Down

0 comments on commit 5191197

Please sign in to comment.