Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add property checks for BinNat #1176

Merged
merged 6 commits into from Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 1 addition & 2 deletions test_workspace/AvlTree.bosatsu
Expand Up @@ -210,8 +210,7 @@ contains_test = (
])
)

def eq_i(a, b):
cmp_Int(a, b) matches EQ
eq_i = eq_Int

def add_increases_size(t, i, msg):
s0 = size(t)
Expand Down
121 changes: 110 additions & 11 deletions test_workspace/BinNat.bosatsu
Expand Up @@ -2,7 +2,8 @@ package Bosatsu/BinNat

from Bosatsu/Nat import Nat, Zero as NatZero, Succ as NatSucc, times2 as times2_Nat

export BinNat(), toInt, toNat, toBinNat, next, add_BinNat, times2, div2, prev
export (BinNat(), toInt, toNat, toBinNat, next, add_BinNat, times2, div2,
prev, times_BinNat, exp, cmp_BinNat, is_even, sub_BinNat, eq_BinNat)
# a natural number with three variants:
# Zero = 0
# Odd(n) = 2n + 1
Expand All @@ -11,6 +12,9 @@ export BinNat(), toInt, toNat, toBinNat, next, add_BinNat, times2, div2, prev
# Zero, Odd(Zero), Even(Zero), Odd(Odd(Zero)), Even(Odd(Zero))
enum BinNat: Zero, Odd(half: BinNat), Even(half1: BinNat)

def is_even(b: BinNat) -> Bool:
b matches Zero | Even(_)

# Convert a BinNat into the equivalent Int
# this is O(log(b)) operation
def toInt(b: BinNat) -> Int:
Expand All @@ -37,7 +41,51 @@ def toBinNat(n: Int) -> BinNat:
(dec(n), fns)
)
# Now apply all the transformations
fns.foldLeft(Zero, \n, fn -> fn(n))
fns.foldLeft(Zero, (n, fn) -> fn(n))

def cmp_BinNat(a: BinNat, b: BinNat) -> Comparison:
recur a:
case Zero:
match b:
case Odd(_) | Even(_): LT
case Zero: EQ
case Odd(a1):
match b:
case Odd(b1): cmp_BinNat(a1, b1)
case Even(b1):
# 2n + 1 <> 2m + 2
# if n <= m, LT
# if n > m GT
match cmp_BinNat(a1, b1):
case LT | EQ: LT
case GT: GT
case Zero: GT
case Even(a1):
match b:
case Even(b1): cmp_BinNat(a1, b1)
case Odd(b1):
# 2n + 2 <> 2m + 1
# if n >= m, GT
# if n < m LT
match cmp_BinNat(a1, b1):
case GT | EQ: GT
case LT: LT
case Zero: GT

# this is more efficient potentially than cmp_BinNat
# because at the first difference we can stop. In the worst
# case of equality, the cost is the same.
def eq_BinNat(a: BinNat, b: BinNat) -> Bool:
recur a:
case Zero: b matches Zero
case Odd(n):
match b:
case Odd(m): eq_BinNat(n, m)
case _: False
case Even(n):
match b:
case Even(m): eq_BinNat(n, m)
case _: False

# Return the next number
def next(b: BinNat) -> BinNat:
Expand All @@ -53,14 +101,11 @@ def next(b: BinNat) -> BinNat:
# Return the previous number if the number is > 0, else return 0
def prev(b: BinNat) -> BinNat:
recur b:
Zero: Zero
Odd(Zero):
# This breaks the law below because 0 - 1 = 0 in this function
Zero
Odd(half):
case Zero | Odd(Zero): Zero
case Odd(half):
# (2n + 1) - 1 = 2n = 2(n-1 + 1)
Even(prev(half))
Even(half1):
case Even(half1):
# 2(n + 1) - 1 = 2n + 1
Odd(half1)

Expand Down Expand Up @@ -97,11 +142,51 @@ def times2(b: BinNat) -> BinNat:
#2(2(n + 1)) = 2((2n + 1) + 1)
Even(Odd(n))

def sub_BinNat(left: BinNat, right: BinNat) -> BinNat:
# invariant: left > right
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see below: invariant: left >= right, so the comment below is wrong, we can return Zero.

# we always return >= 1
def loop(left, right):
recur left:
case Odd(left) as odd:
match right:
case Zero: odd
case Odd(right):
# (2n + 1) - (2m + 1) = 2(n - m)
times2(loop(left, right))
case Even(right):
# (2n + 1) - (2m + 2) = 2(n - m) - 1
# note if (2n + 1) > (2m + 2), then n > m
times2(loop(left, right)).prev()
case Even(left) as even:
match right:
case Zero: even
case Odd(right):
# (2n + 2) - (2m + 1)
# if n >= m: 2(n - m) + 1
# if (2n + 2) > (2m + 1), then n > m
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is wrong... if (2n + 2) > (2m + 1) then n >= m. So, that means the invariant must actually be left >= right. So that means we only know that (2n + 2) >= (2m + 1), but that's okay, because we know an odd number can't equal an even number, so we recover n >= m.

Odd(loop(left, right))
case Even(right):
# (2n + 2) - (2m + 2) = 2(n - m)
times2(loop(left, right))
case Zero:
# this is last because it is actually unreachable
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this isn't true. The invariant is left >= right

# due to the invariant that left > right
# but the type system can't express this
# we put it last to improve performance
Zero

if cmp_BinNat(left, right) matches (LT | EQ): Zero
else: loop(left, right)

def div2(b: BinNat) -> BinNat:
match b:
case Zero: Zero
case Odd(n): n
case Even(n): prev(n)
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this was a bug! yay for property checks!

case Odd(n):
# (2n + 1)/2 = n
n
case Even(n):
# (2n + 2)/2 = n + 1
next(n)

# multiply two BinNat together
def times_BinNat(left: BinNat, right: BinNat) -> BinNat:
Expand All @@ -122,6 +207,21 @@ def times_BinNat(left: BinNat, right: BinNat) -> BinNat:
prod = times_BinNat(left, right)
times2(prod.add_BinNat(right))

one = Odd(Zero)

def exp(base: BinNat, power: BinNat) -> BinNat:
recur power:
case Zero: one
case Odd(n):
# b^(2n + 1) == (b^n) * (b^n) * b
bn = exp(base, n)
bn.times_BinNat(bn).times_BinNat(base)
case Even(n):
# b^(2n + 2) = (b^n * b)^2
bn = exp(base, n)
bn1 = bn.times_BinNat(base)
bn1.times_BinNat(bn1)

# fold(fn, a, Zero) = a
# fold(fn, a, n) = fold(fn, fn(a, n - 1), n - 1)
def fold_left_BinNat(fn: (a, BinNat) -> a, init: a, cnt: BinNat) -> a:
Expand Down Expand Up @@ -158,7 +258,6 @@ def next_law(i, msg):
def times2_law(i, msg):
Assertion(i.toBinNat().times2().toInt().eq_Int(i.times(2)), msg)

one = Odd(Zero)
two = one.next()
three = two.next()
four = three.next()
Expand Down
2 changes: 1 addition & 1 deletion test_workspace/Nat.bosatsu
Expand Up @@ -122,7 +122,7 @@ n4 = Succ(n3)
n5 = Succ(n4)

def operator ==(i0: Int, i1: Int):
cmp_Int(i0, i1) matches EQ
eq_Int(i0, i1)

def addLaw(n1: Nat, n2: Nat, label: String) -> Test:
Assertion(add(n1, n2).to_Int() == (n1.to_Int() + n2.to_Int()), label)
Expand Down
76 changes: 68 additions & 8 deletions test_workspace/NumberProps.bosatsu
@@ -1,6 +1,10 @@
package Bosatsu/NumberProps

from Bosatsu/BinNat import (BinNat, toBinNat as int_to_BinNat)
from Bosatsu/BinNat import (BinNat, toBinNat as int_to_BinNat, is_even as is_even_BinNat,
times2 as times2_BinNat, div2 as div2_BinNat, Zero as BNZero, Even as BNEven,
times_BinNat, exp as exp_BinNat, cmp_BinNat, toInt as binNat_to_Int, add_BinNat,
next as next_BinNat, sub_BinNat, eq_BinNat,
)
from Bosatsu/Nat import (Nat, Zero as NZero, Succ as NSucc, to_Nat as int_to_Nat, is_even as is_even_Nat,
times2 as times2_Nat, div2 as div2_Nat, cmp_Nat, to_Int as nat_to_Int, add as add_Nat,
mult as mult_Nat, exp as exp_Nat, sub_Nat)
Expand All @@ -17,8 +21,6 @@ rand_Int: Rand[Int] = from_pair(int_range(128), geometric_Int)
rand_Nat: Rand[Nat] = rand_Int.map_Rand(int_to_Nat)
rand_BinNat: Rand[BinNat] = rand_Int.map_Rand(int_to_BinNat)

eq_Int = (a, b) -> a.cmp_Int(b) matches EQ

int_props = suite_Prop(
"Int props",
[
Expand Down Expand Up @@ -51,6 +53,7 @@ def exp_Int(base: Int, power: Int) -> Int:
int_loop(power, 1, (p, acc) -> (p.sub(1), acc.times(base)))

small_rand_Nat: Rand[Nat] = int_range(7).map_Rand(int_to_Nat)
small_rand_BinNat: Rand[BinNat] = int_range(7).map_Rand(int_to_BinNat)

nat_props = suite_Prop(
"Nat props",
Expand All @@ -72,7 +75,7 @@ nat_props = suite_Prop(
forall_Prop(prod_Rand(rand_Nat, rand_Nat), "add homomorphism", ((n1, n2)) -> (
n3 = add_Nat(n1, n2)
i3 = add(n1.nat_to_Int(), n2.nat_to_Int())
Assertion(cmp_Int(n3.nat_to_Int(), i3) matches EQ, "add homomorphism")
Assertion(eq_Int(n3.nat_to_Int(), i3), "add homomorphism")
)),
forall_Prop(prod_Rand(rand_Nat, rand_Nat), "sub_Nat homomorphism", ((n1, n2)) -> (
n3 = sub_Nat(n1, n2)
Expand All @@ -81,19 +84,19 @@ nat_props = suite_Prop(
match cmp_Int(i1, i2):
case EQ | GT:
i3 = sub(i1, i2)
Assertion(cmp_Int(n3.nat_to_Int(), i3) matches EQ, "sub_Nat homomorphism")
Assertion(eq_Int(n3.nat_to_Int(), i3), "sub_Nat homomorphism")
case LT:
Assertion(n3 matches NZero, "sub to zero")
)),
forall_Prop(prod_Rand(rand_Nat, rand_Nat), "mult homomorphism", ((n1, n2)) -> (
n3 = mult_Nat(n1, n2)
i3 = times(n1.nat_to_Int(), n2.nat_to_Int())
Assertion(cmp_Int(n3.nat_to_Int(), i3) matches EQ, "mult homomorphism")
Assertion(eq_Int(n3.nat_to_Int(), i3), "mult homomorphism")
)),
forall_Prop(prod_Rand(small_rand_Nat, small_rand_Nat), "exp homomorphism", ((n1, n2)) -> (
n3 = exp_Nat(n1, n2)
i3 = exp_Int(n1.nat_to_Int(), n2.nat_to_Int())
Assertion(cmp_Int(n3.nat_to_Int(), i3) matches EQ, "exp homomorphism")
Assertion(eq_Int(n3.nat_to_Int(), i3), "exp homomorphism")
)),
forall_Prop(rand_Nat, "times2 == x -> mult(2, x)", n -> (
t2 = n.times2_Nat()
Expand All @@ -103,7 +106,64 @@ nat_props = suite_Prop(
]
)

all_props = [int_props, nat_props]
binnat_props = suite_Prop(
"BinNat props",
[
forall_Prop(rand_BinNat, "if is_even(n) then times2(div2(n)) == n", n -> (
if is_even_BinNat(n):
n1 = times2_BinNat(div2_BinNat(n))
Assertion(cmp_BinNat(n1, n) matches EQ, "times2/div2")
else:
# we return the previous number
n1 = times2_BinNat(div2_BinNat(n))
Assertion(cmp_BinNat(n1.next_BinNat(), n) matches EQ, "times2/div2")
)),
forall_Prop(prod_Rand(rand_BinNat, rand_BinNat), "cmp_BinNat matches cmp_Int", ((n1, n2)) -> (
cmp_n = cmp_BinNat(n1, n2)
cmp_i = cmp_Int(n1.binNat_to_Int(), n2.binNat_to_Int())
Assertion(cmp_Comparison(cmp_n, cmp_i) matches EQ, "cmp_BinNat")
)),
forall_Prop(prod_Rand(rand_BinNat, rand_BinNat), "cmp_BinNat matches eq_BinNat", ((n1, n2)) -> (
eq1 = cmp_BinNat(n1, n2) matches EQ
eq2 = eq_BinNat(n1, n2)
correct = (eq1, eq2) matches (True, True) | (False, False)
Assertion(correct, "cmp vs eq consistency")
)),
forall_Prop(prod_Rand(rand_BinNat, rand_BinNat), "add homomorphism", ((n1, n2)) -> (
n3 = add_BinNat(n1, n2)
i3 = add(n1.binNat_to_Int(), n2.binNat_to_Int())
Assertion(eq_Int(n3.binNat_to_Int(), i3), "add homomorphism")
)),
forall_Prop(prod_Rand(rand_BinNat, rand_BinNat), "sub_BinNat homomorphism", ((n1, n2)) -> (
n3 = sub_BinNat(n1, n2)
i1 = n1.binNat_to_Int()
i2 = n2.binNat_to_Int()
match cmp_Int(i1, i2):
case EQ | GT:
i3 = sub(i1, i2)
Assertion(eq_Int(n3.binNat_to_Int(), i3), "sub_BinNat homomorphism")
case LT:
Assertion(n3 matches BNZero, "sub to zero")
)),
forall_Prop(prod_Rand(rand_BinNat, rand_BinNat), "mult homomorphism", ((n1, n2)) -> (
n3 = times_BinNat(n1, n2)
i3 = times(n1.binNat_to_Int(), n2.binNat_to_Int())
Assertion(eq_Int(n3.binNat_to_Int(), i3), "mult homomorphism")
)),
forall_Prop(prod_Rand(small_rand_BinNat, small_rand_BinNat), "exp homomorphism", ((n1, n2)) -> (
n3 = exp_BinNat(n1, n2)
i3 = exp_Int(n1.binNat_to_Int(), n2.binNat_to_Int())
Assertion(eq_Int(n3.binNat_to_Int(), i3), "exp homomorphism")
)),
forall_Prop(rand_BinNat, "times2 == x -> mult(2, x)", n -> (
t2 = n.times2_BinNat()
t2_2 = times_BinNat(n, BNEven(BNZero))
Assertion(cmp_BinNat(t2, t2_2) matches EQ, "times2 == mult(2, _)")
)),
]
)

all_props = [int_props, nat_props, binnat_props]

seed = 123456
test = TestSuite("properties", [
Expand Down
2 changes: 1 addition & 1 deletion test_workspace/Properties.bosatsu
Expand Up @@ -33,7 +33,7 @@ def run_Prop(prop: Prop, trials: Int, seed: Int) -> Test:

signed64 = int_range(1 << 64).map_Rand(i -> i - (1 << 63))

def operator ==(a, b): cmp_Int(a, b) matches EQ
def operator ==(a, b): eq_Int(a, b)

not_law = forall_Prop(
signed64,
Expand Down